From e5ac6dd432ec4c16c3d2801e8ef0c1f125380794 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Thu, 28 Aug 2025 14:54:07 +0200 Subject: [PATCH 01/35] WIP make provider events a proper irpc protocol and allow configuring notifications/requests for each event type. --- Cargo.lock | 4 - Cargo.toml | 2 +- src/provider.rs | 5 +- src/provider/event_proto.rs | 290 ++++++++++++++++++++++++++++++++++++ 4 files changed, 294 insertions(+), 7 deletions(-) create mode 100644 src/provider/event_proto.rs diff --git a/Cargo.lock b/Cargo.lock index 4068354f7..1a4de777e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1943,8 +1943,6 @@ dependencies = [ [[package]] name = "irpc" version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9f8f1d0987ea9da3d74698f921d0a817a214c83b2635a33ed4bc3efa4de1acd" dependencies = [ "anyhow", "futures-buffered", @@ -1966,8 +1964,6 @@ dependencies = [ [[package]] name = "irpc-derive" version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e0b26b834d401a046dd9d47bc236517c746eddbb5d25ff3e1a6075bfa4eebdb" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index bcd5f42d0..3a642632c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,7 +42,7 @@ self_cell = "1.1.0" genawaiter = { version = "0.99.1", features = ["futures03"] } iroh-base = "0.91.1" reflink-copy = "0.1.24" -irpc = { version = "0.7.0", features = ["rpc", "quinn_endpoint_setup", "spans", "stream", "derive"], default-features = false } +irpc = { version = "0.7.0", features = ["rpc", "quinn_endpoint_setup", "spans", "stream", "derive"], default-features = false, path = "../irpc" } iroh-metrics = { version = "0.35" } [dev-dependencies] diff --git a/src/provider.rs b/src/provider.rs index 61af8f6e1..141d674c6 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -20,7 +20,7 @@ use iroh::{ }; use irpc::channel::oneshot; use n0_future::StreamExt; -use serde::de::DeserializeOwned; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; use tokio::{io::AsyncRead, select, sync::mpsc}; use tracing::{debug, debug_span, error, warn, Instrument}; @@ -33,6 +33,7 @@ use crate::{ }, Hash, }; +mod event_proto; /// Provider progress events, to keep track of what the provider is doing. /// @@ -129,7 +130,7 @@ pub enum Event { } /// Statistics about a successful or failed transfer. -#[derive(Debug)] +#[derive(Debug, Serialize, Deserialize)] pub struct TransferStats { /// The number of bytes sent that are part of the payload. pub payload_bytes_sent: u64, diff --git a/src/provider/event_proto.rs b/src/provider/event_proto.rs new file mode 100644 index 000000000..cc6e1aeab --- /dev/null +++ b/src/provider/event_proto.rs @@ -0,0 +1,290 @@ +use std::fmt::Debug; + +use iroh::NodeId; +use irpc::{ + channel::{none::NoSender, oneshot}, + rpc_requests, +}; +use serde::{Deserialize, Serialize}; +use snafu::Snafu; + +use crate::{protocol::ChunkRangesSeq, provider::TransferStats, Hash}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +#[repr(u8)] +pub enum EventMode { + /// We don't get these kinds of events at all + #[default] + None, + /// We get a notification for these kinds of events + Notify, + /// We can respond to these kinds of events, either by aborting or by + /// e.g. introducing a delay for throttling. 
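+    ///
+    /// A sketch of a mask that only intercepts connect events (illustrative
+    /// only; `EventMask` is defined further down in this file and derives
+    /// `Default`):
+    ///
+    /// ```ignore
+    /// let mask = EventMask {
+    ///     connected: EventMode::Request,
+    ///     ..Default::default()
+    /// };
+    /// ```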
+ Request, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +#[repr(u8)] +pub enum EventMode2 { + /// We don't get these kinds of events at all + #[default] + None, + /// We get a notification for these kinds of events + Notify, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum AbortReason { + RateLimited, + Permission, +} + +#[derive(Debug, Snafu)] +pub enum ClientError { + RateLimited, + Permission, + #[snafu(transparent)] + Irpc { + source: irpc::Error, + }, +} + +impl From for ClientError { + fn from(value: AbortReason) -> Self { + match value { + AbortReason::RateLimited => ClientError::RateLimited, + AbortReason::Permission => ClientError::Permission, + } + } +} + +pub type EventResult = Result<(), AbortReason>; +pub type ClientResult = Result<(), ClientError>; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub struct EventMask { + connected: EventMode, + get: EventMode, + get_many: EventMode, + push: EventMode, + transfer: EventMode, + transfer_complete: EventMode2, + transfer_aborted: EventMode2, +} + +/// Newtype wrapper that wraps an event so that it is a distinct type for the notify variant. +#[derive(Debug, Serialize, Deserialize)] +pub struct Notify(T); + +#[derive(Debug, Default)] +pub struct Client { + mask: EventMask, + inner: Option>, +} + +/// A new get request was received from the provider. +#[derive(Debug, Serialize, Deserialize)] +pub struct GetRequestReceived { + /// The connection id. Multiple requests can be sent over the same connection. + pub connection_id: u64, + /// The request id. There is a new id for each request. + pub request_id: u64, + /// The root hash of the request. + pub hash: Hash, + /// The exact query ranges of the request. + pub ranges: ChunkRangesSeq, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct GetManyRequestReceived { + /// The connection id. Multiple requests can be sent over the same connection. + pub connection_id: u64, + /// The request id. There is a new id for each request. + pub request_id: u64, + /// The root hash of the request. + pub hashes: Vec, + /// The exact query ranges of the request. + pub ranges: ChunkRangesSeq, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct PushRequestReceived { + /// The connection id. Multiple requests can be sent over the same connection. + pub connection_id: u64, + /// The request id. There is a new id for each request. + pub request_id: u64, + /// The root hash of the request. + pub hash: Hash, + /// The exact query ranges of the request. + pub ranges: ChunkRangesSeq, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct TransferProgress { + /// The connection id. Multiple requests can be sent over the same connection. + pub connection_id: u64, + /// The request id. There is a new id for each request. + pub request_id: u64, + /// The index of the blob in the request. 0 for the first blob or for raw blob requests. + pub index: u64, + /// The end offset of the chunk that was sent. + pub end_offset: u64, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct TransferStarted { + /// The connection id. Multiple requests can be sent over the same connection. + pub connection_id: u64, + /// The request id. There is a new id for each request. + pub request_id: u64, + /// The index of the blob in the request. 0 for the first blob or for raw blob requests. + pub index: u64, + /// The hash of the blob. This is the hash of the request for the first blob, the child hash (index-1) for subsequent blobs. 
+ pub hash: Hash, + /// The size of the blob. This is the full size of the blob, not the size we are sending. + pub size: u64, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct TransferCompleted { + /// The connection id. Multiple requests can be sent over the same connection. + pub connection_id: u64, + /// The request id. There is a new id for each request. + pub request_id: u64, + /// Statistics about the transfer. + pub stats: Box, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct TransferAborted { + /// The connection id. Multiple requests can be sent over the same connection. + pub connection_id: u64, + /// The request id. There is a new id for each request. + pub request_id: u64, + /// Statistics about the part of the transfer that was aborted. + pub stats: Option>, +} + +/// Client for progress notifications. +/// +/// For most event types, the client can be configured to either send notifications or requests that +/// can have a response. +impl Client { + /// A client that does not send anything. + pub const NONE: Self = Self { + mask: EventMask { + connected: EventMode::None, + get: EventMode::None, + get_many: EventMode::None, + push: EventMode::None, + transfer: EventMode::None, + transfer_complete: EventMode2::None, + transfer_aborted: EventMode2::None, + }, + inner: None, + }; + + pub async fn client_connected(&self, f: impl Fn() -> ClientConnected) -> ClientResult { + Ok(if let Some(client) = &self.inner { + match self.mask.connected { + EventMode::None => {} + EventMode::Notify => client.notify(Notify(f())).await?, + EventMode::Request => client.rpc(f()).await??, + } + }) + } + + pub async fn get_request(&self, f: impl Fn() -> GetRequestReceived) -> ClientResult { + Ok(if let Some(client) = &self.inner { + match self.mask.get { + EventMode::None => {} + EventMode::Notify => client.notify(Notify(f())).await?, + EventMode::Request => client.rpc(f()).await??, + } + }) + } + + pub async fn push_request(&self, f: impl Fn() -> PushRequestReceived) -> ClientResult { + Ok(if let Some(client) = &self.inner { + match self.mask.push { + EventMode::None => {} + EventMode::Notify => client.notify(Notify(f())).await?, + EventMode::Request => client.rpc(f()).await??, + } + }) + } + + pub async fn send_get_many_request( + &self, + f: impl Fn() -> GetManyRequestReceived, + ) -> ClientResult { + Ok(if let Some(client) = &self.inner { + match self.mask.get_many { + EventMode::None => {} + EventMode::Notify => client.notify(Notify(f())).await?, + EventMode::Request => client.rpc(f()).await??, + } + }) + } + + pub async fn transfer_progress(&self, f: impl Fn() -> TransferProgress) -> ClientResult { + Ok(if let Some(client) = &self.inner { + match self.mask.transfer { + EventMode::None => {} + EventMode::Notify => client.notify(Notify(f())).await?, + EventMode::Request => client.rpc(f()).await??, + } + }) + } +} + +#[rpc_requests(message = ProviderMessage)] +#[derive(Debug, Serialize, Deserialize)] +pub enum ProviderProto { + /// A new client connected to the provider. + #[rpc(tx = oneshot::Sender)] + #[wrap(ClientConnected)] + ClientConnected { connection_id: u64, node_id: NodeId }, + /// A new client connected to the provider. Notify variant. + #[rpc(tx = NoSender)] + ClientConnectedNotify(Notify), + /// A client disconnected from the provider. + #[rpc(tx = NoSender)] + #[wrap(ConnectionClosed)] + ConnectionClosed { connection_id: u64 }, + + #[rpc(tx = oneshot::Sender)] + /// A new get request was received from the provider. 
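+    /// The handler answers on the oneshot channel with an [`EventResult`]:
+    /// `Ok(())` lets the request proceed, `Err(AbortReason)` rejects it.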
+ GetRequestReceived(GetRequestReceived), + + #[rpc(tx = NoSender)] + /// A new get request was received from the provider. + GetRequestReceivedNotify(Notify), + /// A new get request was received from the provider. + #[rpc(tx = oneshot::Sender)] + GetManyRequestReceived(GetManyRequestReceived), + /// A new get request was received from the provider. + #[rpc(tx = NoSender)] + GetManyRequestReceivedNotify(Notify), + /// A new get request was received from the provider. + #[rpc(tx = oneshot::Sender)] + PushRequestReceived(PushRequestReceived), + /// A new get request was received from the provider. + #[rpc(tx = NoSender)] + PushRequestReceivedNotify(Notify), + /// Transfer for the nth blob started. + #[rpc(tx = NoSender)] + TransferStarted(TransferStarted), + /// Progress of the transfer. + #[rpc(tx = oneshot::Sender)] + TransferProgress(TransferProgress), + /// Progress of the transfer. + #[rpc(tx = NoSender)] + TransferProgressNotify(Notify), + /// Entire transfer completed. + #[rpc(tx = NoSender)] + TransferCompleted(TransferCompleted), + /// Entire transfer aborted. + #[rpc(tx = NoSender)] + TransferAborted(TransferAborted), +} From d17c6f6a4a246dac165b951800597340cb2bacc7 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Thu, 28 Aug 2025 15:03:26 +0200 Subject: [PATCH 02/35] Add transfer_completed and transfer_aborted fn. --- src/provider/event_proto.rs | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/provider/event_proto.rs b/src/provider/event_proto.rs index cc6e1aeab..ac769a476 100644 --- a/src/provider/event_proto.rs +++ b/src/provider/event_proto.rs @@ -236,6 +236,24 @@ impl Client { } }) } + + pub async fn transfer_completed(&self, f: impl Fn() -> TransferCompleted) -> ClientResult { + Ok(if let Some(client) = &self.inner { + match self.mask.transfer_complete { + EventMode2::Notify => client.notify(f()).await?, + EventMode2::None => {} + } + }) + } + + pub async fn transfer_aborted(&self, f: impl Fn() -> TransferAborted) -> ClientResult { + Ok(if let Some(client) = &self.inner { + match self.mask.transfer_aborted { + EventMode2::Notify => client.notify(f()).await?, + EventMode2::None => {} + } + }) + } } #[rpc_requests(message = ProviderMessage)] From b23995e71b479ccba05370a9df8dbb6099259a0d Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Fri, 29 Aug 2025 13:23:35 +0200 Subject: [PATCH 03/35] Nicer proto --- src/provider/event_proto.rs | 620 ++++++++++++++++++++++++++---------- 1 file changed, 444 insertions(+), 176 deletions(-) diff --git a/src/provider/event_proto.rs b/src/provider/event_proto.rs index ac769a476..8713d8aac 100644 --- a/src/provider/event_proto.rs +++ b/src/provider/event_proto.rs @@ -1,36 +1,51 @@ use std::fmt::Debug; -use iroh::NodeId; use irpc::{ - channel::{none::NoSender, oneshot}, - rpc_requests, + channel::{mpsc, none::NoSender, oneshot}, + rpc_requests, Channels, WithChannels, }; use serde::{Deserialize, Serialize}; use snafu::Snafu; -use crate::{protocol::ChunkRangesSeq, provider::TransferStats, Hash}; +use crate::provider::{event_proto::irpc_ext::IrpcClientExt, TransferStats}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] #[repr(u8)] -pub enum EventMode { - /// We don't get these kinds of events at all +pub enum ConnectMode { + /// We don't get notification of connect events at all. #[default] None, - /// We get a notification for these kinds of events + /// We get a notification for connect events. Notify, - /// We can respond to these kinds of events, either by aborting or by - /// e.g. 
introducing a delay for throttling. + /// We get a request for connect events and can reject incoming connections. Request, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] #[repr(u8)] -pub enum EventMode2 { - /// We don't get these kinds of events at all +pub enum RequestMode { + /// We don't get request events at all. #[default] None, - /// We get a notification for these kinds of events + /// We get a notification for each request. Notify, + /// We get a request for each request, and can reject incoming requests. + Request, + /// We get a notification for each request as well as detailed transfer events. + NotifyLog, + /// We get a request for each request, and can reject incoming requests. + /// We also get detailed transfer events. + RequestLog, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +#[repr(u8)] +pub enum ThrottleMode { + /// We don't get these kinds of events at all + #[default] + None, + /// We call throttle to give the event handler a way to throttle requests + Throttle, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] @@ -58,111 +73,154 @@ impl From for ClientError { } } +impl From for ClientError { + fn from(value: irpc::channel::RecvError) -> Self { + ClientError::Irpc { + source: value.into(), + } + } +} + +impl From for ClientError { + fn from(value: irpc::channel::SendError) -> Self { + ClientError::Irpc { + source: value.into(), + } + } +} + pub type EventResult = Result<(), AbortReason>; pub type ClientResult = Result<(), ClientError>; #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub struct EventMask { - connected: EventMode, - get: EventMode, - get_many: EventMode, - push: EventMode, - transfer: EventMode, - transfer_complete: EventMode2, - transfer_aborted: EventMode2, + connected: ConnectMode, + get: RequestMode, + get_many: RequestMode, + push: RequestMode, + /// throttling is somewhat costly, so you can disable it completely + throttle: ThrottleMode, +} + +impl EventMask { + /// Everything is disabled. You won't get any events, but there is also no runtime cost. + pub const NONE: Self = Self { + connected: ConnectMode::None, + get: RequestMode::None, + get_many: RequestMode::None, + push: RequestMode::None, + throttle: ThrottleMode::None, + }; + + /// You get asked for every single thing that is going on and can intervene/throttle. + pub const ALL: Self = Self { + connected: ConnectMode::Request, + get: RequestMode::RequestLog, + get_many: RequestMode::RequestLog, + push: RequestMode::RequestLog, + throttle: ThrottleMode::Throttle, + }; + + /// You get notified for every single thing that is going on, but can't intervene. + pub const NOTIFY_ALL: Self = Self { + connected: ConnectMode::Notify, + get: RequestMode::NotifyLog, + get_many: RequestMode::NotifyLog, + push: RequestMode::NotifyLog, + throttle: ThrottleMode::None, + }; } /// Newtype wrapper that wraps an event so that it is a distinct type for the notify variant. #[derive(Debug, Serialize, Deserialize)] pub struct Notify(T); -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone)] pub struct Client { mask: EventMask, inner: Option>, } -/// A new get request was received from the provider. -#[derive(Debug, Serialize, Deserialize)] -pub struct GetRequestReceived { - /// The connection id. Multiple requests can be sent over the same connection. - pub connection_id: u64, - /// The request id. There is a new id for each request. - pub request_id: u64, - /// The root hash of the request. - pub hash: Hash, - /// The exact query ranges of the request. 
- pub ranges: ChunkRangesSeq, +#[derive(Debug, Default)] +enum RequestUpdates { + /// Request tracking was not configured, all ops are no-ops + #[default] + None, + /// Active request tracking, all ops actually send + Active(mpsc::Sender), + /// Disabled request tracking, we just hold on to the sender so it drops + /// once the request is completed or aborted. + Disabled(mpsc::Sender), } -#[derive(Debug, Serialize, Deserialize)] -pub struct GetManyRequestReceived { - /// The connection id. Multiple requests can be sent over the same connection. - pub connection_id: u64, - /// The request id. There is a new id for each request. - pub request_id: u64, - /// The root hash of the request. - pub hashes: Vec, - /// The exact query ranges of the request. - pub ranges: ChunkRangesSeq, +pub struct RequestTracker { + updates: RequestUpdates, + throttle: Option<(irpc::Client, u64, u64)>, } -#[derive(Debug, Serialize, Deserialize)] -pub struct PushRequestReceived { - /// The connection id. Multiple requests can be sent over the same connection. - pub connection_id: u64, - /// The request id. There is a new id for each request. - pub request_id: u64, - /// The root hash of the request. - pub hash: Hash, - /// The exact query ranges of the request. - pub ranges: ChunkRangesSeq, -} +impl RequestTracker { + fn new( + updates: RequestUpdates, + throttle: Option<(irpc::Client, u64, u64)>, + ) -> Self { + Self { updates, throttle } + } -#[derive(Debug, Serialize, Deserialize)] -pub struct TransferProgress { - /// The connection id. Multiple requests can be sent over the same connection. - pub connection_id: u64, - /// The request id. There is a new id for each request. - pub request_id: u64, - /// The index of the blob in the request. 0 for the first blob or for raw blob requests. - pub index: u64, - /// The end offset of the chunk that was sent. - pub end_offset: u64, -} + /// A request tracker that doesn't track anything. + const NONE: Self = Self { + updates: RequestUpdates::None, + throttle: None, + }; -#[derive(Debug, Serialize, Deserialize)] -pub struct TransferStarted { - /// The connection id. Multiple requests can be sent over the same connection. - pub connection_id: u64, - /// The request id. There is a new id for each request. - pub request_id: u64, - /// The index of the blob in the request. 0 for the first blob or for raw blob requests. - pub index: u64, - /// The hash of the blob. This is the hash of the request for the first blob, the child hash (index-1) for subsequent blobs. - pub hash: Hash, - /// The size of the blob. This is the full size of the blob, not the size we are sending. - pub size: u64, -} + /// Transfer for index `index` started, size `size` + pub async fn transfer_started(&self, index: u64, size: u64) -> irpc::Result<()> { + if let RequestUpdates::Active(tx) = &self.updates { + tx.send(RequestUpdate::Started(TransferStarted { index, size })) + .await?; + } + Ok(()) + } -#[derive(Debug, Serialize, Deserialize)] -pub struct TransferCompleted { - /// The connection id. Multiple requests can be sent over the same connection. - pub connection_id: u64, - /// The request id. There is a new id for each request. - pub request_id: u64, - /// Statistics about the transfer. - pub stats: Box, -} + /// Transfer progress for the previously reported blob, end_offset is the new end offset in bytes. 
+ pub async fn transfer_progress(&mut self, end_offset: u64) -> ClientResult { + if let RequestUpdates::Active(tx) = &mut self.updates { + tx.try_send(RequestUpdate::Progress(TransferProgress { end_offset })) + .await?; + } + if let Some((throttle, connection_id, request_id)) = &self.throttle { + throttle + .rpc(Throttle { + connection_id: *connection_id, + request_id: *request_id, + }) + .await??; + } + Ok(()) + } -#[derive(Debug, Serialize, Deserialize)] -pub struct TransferAborted { - /// The connection id. Multiple requests can be sent over the same connection. - pub connection_id: u64, - /// The request id. There is a new id for each request. - pub request_id: u64, - /// Statistics about the part of the transfer that was aborted. - pub stats: Option>, + /// Transfer completed for the previously reported blob. + pub async fn transfer_completed( + &mut self, + f: impl Fn() -> Box, + ) -> irpc::Result<()> { + if let RequestUpdates::Active(tx) = &self.updates { + tx.send(RequestUpdate::Completed(TransferCompleted { stats: f() })) + .await?; + } + Ok(()) + } + + /// Transfer aborted for the previously reported blob. + pub async fn transfer_aborted( + &mut self, + f: impl Fn() -> Option>, + ) -> irpc::Result<()> { + if let RequestUpdates::Active(tx) = &self.updates { + tx.send(RequestUpdate::Aborted(TransferAborted { stats: f() })) + .await?; + } + Ok(()) + } } /// Client for progress notifications. @@ -172,87 +230,132 @@ pub struct TransferAborted { impl Client { /// A client that does not send anything. pub const NONE: Self = Self { - mask: EventMask { - connected: EventMode::None, - get: EventMode::None, - get_many: EventMode::None, - push: EventMode::None, - transfer: EventMode::None, - transfer_complete: EventMode2::None, - transfer_aborted: EventMode2::None, - }, + mask: EventMask::NONE, inner: None, }; + /// A new client has been connected. pub async fn client_connected(&self, f: impl Fn() -> ClientConnected) -> ClientResult { Ok(if let Some(client) = &self.inner { match self.mask.connected { - EventMode::None => {} - EventMode::Notify => client.notify(Notify(f())).await?, - EventMode::Request => client.rpc(f()).await??, + ConnectMode::None => {} + ConnectMode::Notify => client.notify(Notify(f())).await?, + ConnectMode::Request => client.rpc(f()).await??, } }) } - pub async fn get_request(&self, f: impl Fn() -> GetRequestReceived) -> ClientResult { - Ok(if let Some(client) = &self.inner { - match self.mask.get { - EventMode::None => {} - EventMode::Notify => client.notify(Notify(f())).await?, - EventMode::Request => client.rpc(f()).await??, - } - }) - } - - pub async fn push_request(&self, f: impl Fn() -> PushRequestReceived) -> ClientResult { - Ok(if let Some(client) = &self.inner { - match self.mask.push { - EventMode::None => {} - EventMode::Notify => client.notify(Notify(f())).await?, - EventMode::Request => client.rpc(f()).await??, - } - }) + /// Start a get request. You will get back either an error if the request should not proceed, or a + /// [`RequestTracker`] that you can use to log progress for this particular request. + /// + /// Depending on the event sender config, the returned tracker might be a no-op. + pub async fn get_request( + &self, + f: impl FnOnce() -> GetRequestReceived, + ) -> Result { + self.request(f).await } - pub async fn send_get_many_request( + // Start a get_many request. You will get back either an error if the request should not proceed, or a + /// [`RequestTracker`] that you can use to log progress for this particular request. 
+ /// + /// Depending on the event sender config, the returned tracker might be a no-op. + pub async fn get_many_request( &self, - f: impl Fn() -> GetManyRequestReceived, - ) -> ClientResult { - Ok(if let Some(client) = &self.inner { - match self.mask.get_many { - EventMode::None => {} - EventMode::Notify => client.notify(Notify(f())).await?, - EventMode::Request => client.rpc(f()).await??, - } - }) + f: impl FnOnce() -> GetManyRequestReceived, + ) -> Result { + self.request(f).await } - pub async fn transfer_progress(&self, f: impl Fn() -> TransferProgress) -> ClientResult { - Ok(if let Some(client) = &self.inner { - match self.mask.transfer { - EventMode::None => {} - EventMode::Notify => client.notify(Notify(f())).await?, - EventMode::Request => client.rpc(f()).await??, - } - }) + // Start a push request. You will get back either an error if the request should not proceed, or a + /// [`RequestTracker`] that you can use to log progress for this particular request. + /// + /// Depending on the event sender config, the returned tracker might be a no-op. + pub async fn push_request( + &self, + f: impl FnOnce() -> PushRequestReceived, + ) -> Result { + self.request(f).await } - pub async fn transfer_completed(&self, f: impl Fn() -> TransferCompleted) -> ClientResult { - Ok(if let Some(client) = &self.inner { - match self.mask.transfer_complete { - EventMode2::Notify => client.notify(f()).await?, - EventMode2::None => {} + /// Abstract request, to DRY the 3 to 4 request types. + /// + /// DRYing stuff with lots of bounds is no fun at all... + async fn request(&self, f: impl FnOnce() -> Req) -> Result + where + Req: Request, + ProviderProto: From, + ProviderMessage: From>, + Req: Channels< + ProviderProto, + Tx = oneshot::Sender, + Rx = mpsc::Receiver, + >, + ProviderProto: From>, + ProviderMessage: From, ProviderProto>>, + Notify: Channels>, + { + Ok(self.into_tracker(if let Some(client) = &self.inner { + match self.mask.get { + RequestMode::None => { + if self.mask.throttle == ThrottleMode::Throttle { + // if throttling is enabled, we need to call f to get connection_id and request_id + let msg = f(); + (RequestUpdates::None, msg.id()) + } else { + (RequestUpdates::None, (0, 0)) + } + } + RequestMode::Notify => { + let msg = f(); + let id = msg.id(); + ( + RequestUpdates::Disabled(client.notify_streaming(Notify(msg), 32).await?), + id, + ) + } + RequestMode::Request => { + let msg = f(); + let id = msg.id(); + let (tx, rx) = client.client_streaming(msg, 32).await?; + // bail out if the request is not allowed + rx.await??; + (RequestUpdates::Disabled(tx), id) + } + RequestMode::NotifyLog => { + let msg = f(); + let id = msg.id(); + ( + RequestUpdates::Active(client.notify_streaming(Notify(msg), 32).await?), + id, + ) + } + RequestMode::RequestLog => { + let msg = f(); + let id = msg.id(); + let (tx, rx) = client.client_streaming(msg, 32).await?; + // bail out if the request is not allowed + rx.await??; + (RequestUpdates::Active(tx), id) + } } - }) + } else { + (RequestUpdates::None, (0, 0)) + })) } - pub async fn transfer_aborted(&self, f: impl Fn() -> TransferAborted) -> ClientResult { - Ok(if let Some(client) = &self.inner { - match self.mask.transfer_aborted { - EventMode2::Notify => client.notify(f()).await?, - EventMode2::None => {} - } - }) + fn into_tracker( + &self, + (updates, (connection_id, request_id)): (RequestUpdates, (u64, u64)), + ) -> RequestTracker { + let throttle = match self.mask.throttle { + ThrottleMode::None => None, + ThrottleMode::Throttle => self + .inner + .clone() 
+ .map(|client| (client, connection_id, request_id)), + }; + RequestTracker::new(updates, throttle) } } @@ -261,48 +364,213 @@ impl Client { pub enum ProviderProto { /// A new client connected to the provider. #[rpc(tx = oneshot::Sender)] - #[wrap(ClientConnected)] - ClientConnected { connection_id: u64, node_id: NodeId }, + ClientConnected(ClientConnected), /// A new client connected to the provider. Notify variant. #[rpc(tx = NoSender)] ClientConnectedNotify(Notify), + /// A client disconnected from the provider. #[rpc(tx = NoSender)] #[wrap(ConnectionClosed)] ConnectionClosed { connection_id: u64 }, - #[rpc(tx = oneshot::Sender)] + #[rpc(rx = mpsc::Receiver, tx = oneshot::Sender)] /// A new get request was received from the provider. GetRequestReceived(GetRequestReceived), - #[rpc(tx = NoSender)] + #[rpc(rx = mpsc::Receiver, tx = NoSender)] /// A new get request was received from the provider. GetRequestReceivedNotify(Notify), + /// A new get request was received from the provider. - #[rpc(tx = oneshot::Sender)] + #[rpc(rx = mpsc::Receiver, tx = oneshot::Sender)] GetManyRequestReceived(GetManyRequestReceived), + /// A new get request was received from the provider. - #[rpc(tx = NoSender)] + #[rpc(rx = mpsc::Receiver, tx = NoSender)] GetManyRequestReceivedNotify(Notify), + /// A new get request was received from the provider. - #[rpc(tx = oneshot::Sender)] + #[rpc(rx = mpsc::Receiver, tx = oneshot::Sender)] PushRequestReceived(PushRequestReceived), + /// A new get request was received from the provider. - #[rpc(tx = NoSender)] + #[rpc(rx = mpsc::Receiver, tx = NoSender)] PushRequestReceivedNotify(Notify), - /// Transfer for the nth blob started. - #[rpc(tx = NoSender)] - TransferStarted(TransferStarted), - /// Progress of the transfer. + #[rpc(tx = oneshot::Sender)] - TransferProgress(TransferProgress), - /// Progress of the transfer. - #[rpc(tx = NoSender)] - TransferProgressNotify(Notify), - /// Entire transfer completed. - #[rpc(tx = NoSender)] - TransferCompleted(TransferCompleted), - /// Entire transfer aborted. - #[rpc(tx = NoSender)] - TransferAborted(TransferAborted), + Throttle(Throttle), +} + +trait Request { + fn id(&self) -> (u64, u64); +} + +mod proto { + use iroh::NodeId; + use serde::{Deserialize, Serialize}; + + use super::Request; + use crate::{protocol::ChunkRangesSeq, provider::TransferStats, Hash}; + + #[derive(Debug, Serialize, Deserialize)] + pub struct ClientConnected { + pub connection_id: u64, + pub node_id: NodeId, + } + + /// A new get request was received from the provider. + #[derive(Debug, Serialize, Deserialize)] + pub struct GetRequestReceived { + /// The connection id. Multiple requests can be sent over the same connection. + pub connection_id: u64, + /// The request id. There is a new id for each request. + pub request_id: u64, + /// The root hash of the request. + pub hash: Hash, + /// The exact query ranges of the request. + pub ranges: ChunkRangesSeq, + } + + impl Request for GetRequestReceived { + fn id(&self) -> (u64, u64) { + (self.connection_id, self.request_id) + } + } + + #[derive(Debug, Serialize, Deserialize)] + pub struct GetManyRequestReceived { + /// The connection id. Multiple requests can be sent over the same connection. + pub connection_id: u64, + /// The request id. There is a new id for each request. + pub request_id: u64, + /// The root hash of the request. + pub hashes: Vec, + /// The exact query ranges of the request. 
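+        /// Entries are matched to `hashes` by position.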
+ pub ranges: ChunkRangesSeq, + } + + impl Request for GetManyRequestReceived { + fn id(&self) -> (u64, u64) { + (self.connection_id, self.request_id) + } + } + + #[derive(Debug, Serialize, Deserialize)] + pub struct PushRequestReceived { + /// The connection id. Multiple requests can be sent over the same connection. + pub connection_id: u64, + /// The request id. There is a new id for each request. + pub request_id: u64, + /// The root hash of the request. + pub hash: Hash, + /// The exact query ranges of the request. + pub ranges: ChunkRangesSeq, + } + + impl Request for PushRequestReceived { + fn id(&self) -> (u64, u64) { + (self.connection_id, self.request_id) + } + } + + /// Request to throttle sending for a specific request. + #[derive(Debug, Serialize, Deserialize)] + pub struct Throttle { + /// The connection id. Multiple requests can be sent over the same connection. + pub connection_id: u64, + /// The request id. There is a new id for each request. + pub request_id: u64, + } + + #[derive(Debug, Serialize, Deserialize)] + pub struct TransferProgress { + /// The end offset of the chunk that was sent. + pub end_offset: u64, + } + + #[derive(Debug, Serialize, Deserialize)] + pub struct TransferStarted { + pub index: u64, + pub size: u64, + } + + #[derive(Debug, Serialize, Deserialize)] + pub struct TransferCompleted { + pub stats: Box, + } + + #[derive(Debug, Serialize, Deserialize)] + pub struct TransferAborted { + pub stats: Option>, + } + + /// Stream of updates for a single request + #[derive(Debug, Serialize, Deserialize)] + pub enum RequestUpdate { + /// Start of transfer for a blob, mandatory event + Started(TransferStarted), + /// Progress for a blob - optional event + Progress(TransferProgress), + /// Successful end of transfer + Completed(TransferCompleted), + /// Aborted end of transfer + Aborted(TransferAborted), + } +} +use proto::*; + +mod irpc_ext { + use std::future::Future; + + use irpc::{ + channel::{mpsc, none::NoSender, oneshot}, + Channels, RpcMessage, Service, WithChannels, + }; + + pub trait IrpcClientExt { + fn notify_streaming( + &self, + msg: Req, + local_update_cap: usize, + ) -> impl Future>> + where + S: From, + S::Message: From>, + Req: Channels>, + Update: RpcMessage; + } + + impl IrpcClientExt for irpc::Client { + fn notify_streaming( + &self, + msg: Req, + local_update_cap: usize, + ) -> impl Future>> + where + S: From, + S::Message: From>, + Req: Channels>, + Update: RpcMessage, + { + let client = self.clone(); + async move { + let request = client.request().await?; + match request { + irpc::Request::Local(local) => { + let (req_tx, req_rx) = mpsc::channel(local_update_cap); + local + .send((msg, NoSender, req_rx)) + .await + .map_err(irpc::Error::from)?; + Ok(req_tx) + } + irpc::Request::Remote(remote) => { + let (s, r) = remote.write(msg).await?; + Ok(s.into()) + } + } + } + } + } } From a78c212e8f307210bed17d7ab85010aeff9dfa5d Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Fri, 29 Aug 2025 15:05:26 +0200 Subject: [PATCH 04/35] Update tests --- examples/custom-protocol.rs | 4 +- examples/mdns-discovery.rs | 4 +- examples/random_store.rs | 104 +++--- examples/transfer.rs | 4 +- src/api/blobs.rs | 17 +- src/net_protocol.rs | 8 +- src/provider.rs | 354 ++++++++------------- src/provider/{event_proto.rs => events.rs} | 47 ++- src/tests.rs | 52 +-- 9 files changed, 261 insertions(+), 333 deletions(-) rename src/provider/{event_proto.rs => events.rs} (94%) diff --git a/examples/custom-protocol.rs b/examples/custom-protocol.rs index 
c021b7f0a..6542acd18 100644 --- a/examples/custom-protocol.rs +++ b/examples/custom-protocol.rs @@ -48,7 +48,7 @@ use iroh::{ protocol::{AcceptError, ProtocolHandler, Router}, Endpoint, NodeId, }; -use iroh_blobs::{api::Store, store::mem::MemStore, BlobsProtocol, Hash}; +use iroh_blobs::{api::Store, provider::EventSender2, store::mem::MemStore, BlobsProtocol, Hash}; mod common; use common::{get_or_generate_secret_key, setup_logging}; @@ -100,7 +100,7 @@ async fn listen(text: Vec) -> Result<()> { proto.insert_and_index(text).await?; } // Build the iroh-blobs protocol handler, which is used to download blobs. - let blobs = BlobsProtocol::new(&store, endpoint.clone(), None); + let blobs = BlobsProtocol::new(&store, endpoint.clone(), EventSender2::NONE); // create a router that handles both our custom protocol and the iroh-blobs protocol. let node = Router::builder(endpoint) diff --git a/examples/mdns-discovery.rs b/examples/mdns-discovery.rs index b42f88f47..ef5d0619c 100644 --- a/examples/mdns-discovery.rs +++ b/examples/mdns-discovery.rs @@ -18,7 +18,7 @@ use clap::{Parser, Subcommand}; use iroh::{ discovery::mdns::MdnsDiscovery, protocol::Router, Endpoint, PublicKey, RelayMode, SecretKey, }; -use iroh_blobs::{store::mem::MemStore, BlobsProtocol, Hash}; +use iroh_blobs::{provider::EventSender2, store::mem::MemStore, BlobsProtocol, Hash}; mod common; use common::{get_or_generate_secret_key, setup_logging}; @@ -68,7 +68,7 @@ async fn accept(path: &Path) -> Result<()> { .await?; let builder = Router::builder(endpoint.clone()); let store = MemStore::new(); - let blobs = BlobsProtocol::new(&store, endpoint.clone(), None); + let blobs = BlobsProtocol::new(&store, endpoint.clone(), EventSender2::NONE); let builder = builder.accept(iroh_blobs::ALPN, blobs.clone()); let node = builder.spawn(); diff --git a/examples/random_store.rs b/examples/random_store.rs index ffdd9b826..6f933d511 100644 --- a/examples/random_store.rs +++ b/examples/random_store.rs @@ -6,7 +6,7 @@ use iroh::{SecretKey, Watcher}; use iroh_base::ticket::NodeTicket; use iroh_blobs::{ api::downloader::Shuffled, - provider::Event, + provider::{AbortReason, Event, EventMask, EventSender2, ProviderMessage}, store::fs::FsStore, test::{add_hash_sequences, create_random_blobs}, HashAndFormat, @@ -104,78 +104,66 @@ pub fn dump_provider_events( allow_push: bool, ) -> ( tokio::task::JoinHandle<()>, - mpsc::Sender, + EventSender2, ) { let (tx, mut rx) = mpsc::channel(100); let dump_task = tokio::spawn(async move { while let Some(event) = rx.recv().await { match event { - Event::ClientConnected { - node_id, - connection_id, - permitted, - } => { - permitted.send(true).await.ok(); - println!("Client connected: {node_id} {connection_id}"); + ProviderMessage::ClientConnected(msg) => { + println!("{:?}", msg.inner); + msg.tx.send(Ok(())).await.ok(); } - Event::GetRequestReceived { - connection_id, - request_id, - hash, - ranges, - } => { - println!( - "Get request received: {connection_id} {request_id} {hash} {ranges:?}" - ); + ProviderMessage::ClientConnectedNotify(msg) => { + println!("{:?}", msg.inner); } - Event::TransferCompleted { - connection_id, - request_id, - stats, - } => { - println!("Transfer completed: {connection_id} {request_id} {stats:?}"); + ProviderMessage::ConnectionClosed(msg) => { + println!("{:?}", msg.inner); } - Event::TransferAborted { - connection_id, - request_id, - stats, - } => { - println!("Transfer aborted: {connection_id} {request_id} {stats:?}"); + ProviderMessage::GetRequestReceived(mut msg) => { + 
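+                    // msg.rx carries the per-request update stream: a Started
+                    // event, optional Progress events, then Completed or
+                    // Aborted (see the RequestUpdate enum).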
println!("{:?}", msg.inner); + msg.tx.send(Ok(())).await.ok(); + tokio::spawn(async move { + while let Ok(update) = msg.rx.recv().await { + info!("{update:?}"); + } + }); } - Event::TransferProgress { - connection_id, - request_id, - index, - end_offset, - } => { - info!("Transfer progress: {connection_id} {request_id} {index} {end_offset}"); + ProviderMessage::GetRequestReceivedNotify(msg) => { + println!("{:?}", msg.inner); } - Event::PushRequestReceived { - connection_id, - request_id, - hash, - ranges, - permitted, - } => { - if allow_push { - permitted.send(true).await.ok(); - println!( - "Push request received: {connection_id} {request_id} {hash} {ranges:?}" - ); + ProviderMessage::GetManyRequestReceived(mut msg) => { + println!("{:?}", msg.inner); + msg.tx.send(Ok(())).await.ok(); + tokio::spawn(async move { + while let Ok(update) = msg.rx.recv().await { + info!("{update:?}"); + } + }); + } + ProviderMessage::GetManyRequestReceivedNotify(msg) => { + println!("{:?}", msg.inner); + } + ProviderMessage::PushRequestReceived(msg) => { + println!("{:?}", msg.inner); + let res = if allow_push { + Ok(()) } else { - permitted.send(false).await.ok(); - println!( - "Push request denied: {connection_id} {request_id} {hash} {ranges:?}" - ); - } + Err(AbortReason::Permission) + }; + msg.tx.send(res).await.ok(); + } + ProviderMessage::PushRequestReceivedNotify(msg) => { + println!("{:?}", msg.inner); } - _ => { - info!("Received event: {:?}", event); + ProviderMessage::Throttle(msg) => { + println!("{:?}", msg.inner); + msg.tx.send(Ok(())).await.ok(); } } } }); - (dump_task, tx) + (dump_task, EventSender2::new(tx, EventMask::ALL)) } #[tokio::main] @@ -237,7 +225,7 @@ async fn provide(args: ProvideArgs) -> anyhow::Result<()> { .bind() .await?; let (dump_task, events_tx) = dump_provider_events(args.allow_push); - let blobs = iroh_blobs::BlobsProtocol::new(&store, endpoint.clone(), Some(events_tx)); + let blobs = iroh_blobs::BlobsProtocol::new(&store, endpoint.clone(), events_tx); let router = iroh::protocol::Router::builder(endpoint.clone()) .accept(iroh_blobs::ALPN, blobs) .spawn(); diff --git a/examples/transfer.rs b/examples/transfer.rs index 48fba6ba3..baa1e343c 100644 --- a/examples/transfer.rs +++ b/examples/transfer.rs @@ -1,7 +1,7 @@ use std::path::PathBuf; use iroh::{protocol::Router, Endpoint}; -use iroh_blobs::{store::mem::MemStore, ticket::BlobTicket, BlobsProtocol}; +use iroh_blobs::{provider::EventSender2, store::mem::MemStore, ticket::BlobTicket, BlobsProtocol}; #[tokio::main] async fn main() -> anyhow::Result<()> { @@ -12,7 +12,7 @@ async fn main() -> anyhow::Result<()> { // We initialize an in-memory backing store for iroh-blobs let store = MemStore::new(); // Then we initialize a struct that can accept blobs requests over iroh connections - let blobs = BlobsProtocol::new(&store, endpoint.clone(), None); + let blobs = BlobsProtocol::new(&store, endpoint.clone(), EventSender2::NONE); // Grab all passed in arguments, the first one is the binary itself, so we skip it. 
let args: Vec = std::env::args().skip(1).collect(); diff --git a/src/api/blobs.rs b/src/api/blobs.rs index d0b948598..76f338359 100644 --- a/src/api/blobs.rs +++ b/src/api/blobs.rs @@ -57,7 +57,7 @@ use super::{ }; use crate::{ api::proto::{BatchRequest, ImportByteStreamUpdate}, - provider::StreamContext, + provider::{ReaderContext, WriterContext}, store::IROH_BLOCK_SIZE, util::temp_tag::TempTag, BlobFormat, Hash, HashAndFormat, @@ -1168,16 +1168,21 @@ pub(crate) trait WriteProgress { async fn send_transfer_started(&mut self, index: u64, hash: &Hash, size: u64); } -impl WriteProgress for StreamContext { - async fn notify_payload_write(&mut self, index: u64, offset: u64, len: usize) { - StreamContext::notify_payload_write(self, index, offset, len); +impl WriteProgress for WriterContext { + async fn notify_payload_write(&mut self, _index: u64, offset: u64, len: usize) { + let end_offset = offset + len as u64; + self.payload_bytes_written += len as u64; + self.tracker.transfer_progress(end_offset).await.ok(); } fn log_other_write(&mut self, len: usize) { - StreamContext::log_other_write(self, len); + self.other_bytes_written += len as u64; } async fn send_transfer_started(&mut self, index: u64, hash: &Hash, size: u64) { - StreamContext::send_transfer_started(self, index, hash, size).await + self.tracker + .transfer_started(index, hash, size) + .await + .ok(); } } diff --git a/src/net_protocol.rs b/src/net_protocol.rs index 3e7d9582e..ca64b1a7b 100644 --- a/src/net_protocol.rs +++ b/src/net_protocol.rs @@ -48,7 +48,7 @@ use tracing::error; use crate::{ api::Store, - provider::{Event, EventSender}, + provider::{Event, EventSender2}, ticket::BlobTicket, HashAndFormat, }; @@ -57,7 +57,7 @@ use crate::{ pub(crate) struct BlobsInner { pub(crate) store: Store, pub(crate) endpoint: Endpoint, - pub(crate) events: EventSender, + pub(crate) events: EventSender2, } /// A protocol handler for the blobs protocol. 
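With this change, callers construct the protocol handler with an explicit `EventSender2` instead of an optional events channel. A minimal wiring sketch with all events disabled, assuming an already-bound `iroh::Endpoint` named `endpoint` (this mirrors the `examples/transfer.rs` hunk above):

    let store = MemStore::new();
    let blobs = BlobsProtocol::new(&store, endpoint.clone(), EventSender2::NONE);
    let router = Router::builder(endpoint)
        .accept(iroh_blobs::ALPN, blobs)
        .spawn();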
@@ -75,12 +75,12 @@ impl Deref for BlobsProtocol { } impl BlobsProtocol { - pub fn new(store: &Store, endpoint: Endpoint, events: Option>) -> Self { + pub fn new(store: &Store, endpoint: Endpoint, events: EventSender2) -> Self { Self { inner: Arc::new(BlobsInner { store: store.clone(), endpoint, - events: EventSender::new(events), + events, }), } } diff --git a/src/provider.rs b/src/provider.rs index 141d674c6..b10367911 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -9,7 +9,7 @@ use std::{ ops::{Deref, DerefMut}, pin::Pin, task::Poll, - time::Duration, + time::{Duration, Instant}, }; use anyhow::{Context, Result}; @@ -18,22 +18,23 @@ use iroh::{ endpoint::{self, RecvStream, SendStream}, NodeId, }; -use irpc::channel::oneshot; +use irpc::{channel::oneshot}; use n0_future::StreamExt; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use tokio::{io::AsyncRead, select, sync::mpsc}; -use tracing::{debug, debug_span, error, warn, Instrument}; +use tracing::{debug, debug_span, error, trace, warn, Instrument}; use crate::{ - api::{self, blobs::Bitfield, Store}, - hashseq::HashSeq, - protocol::{ + api::{self, blobs::{Bitfield, WriteProgress}, Store}, hashseq::HashSeq, protocol::{ ChunkRangesSeq, GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, Request, - }, - Hash, + }, provider::events::{ClientConnected, ConnectionClosed, GetRequestReceived, RequestTracker}, Hash }; -mod event_proto; +pub(crate) mod events; +pub use events::EventSender as EventSender2; +pub use events::ProviderMessage; +pub use events::EventMask; +pub use events::AbortReason; /// Provider progress events, to keep track of what the provider is doing. /// @@ -162,19 +163,33 @@ pub async fn read_request(reader: &mut ProgressReader) -> Result { } #[derive(Debug)] -pub struct StreamContext { +pub struct ReaderContext { + /// The start time of the transfer + pub t0: Instant, /// The connection ID from the connection pub connection_id: u64, /// The request ID from the recv stream pub request_id: u64, - /// The number of bytes written that are part of the payload - pub payload_bytes_sent: u64, - /// The number of bytes written that are not part of the payload - pub other_bytes_sent: u64, /// The number of bytes read from the stream pub bytes_read: u64, - /// The progress sender to send events to - pub progress: EventSender, +} + +#[derive(Debug)] +pub struct WriterContext { + /// The start time of the transfer + pub t0: Instant, + /// The connection ID from the connection + pub connection_id: u64, + /// The request ID from the recv stream + pub request_id: u64, + /// The number of bytes read from the stream + pub bytes_read: u64, + /// The number of payload bytes written to the stream + pub payload_bytes_written: u64, + /// The number of bytes written that are not part of the payload + pub other_bytes_written: u64, + /// Way to report progress + pub tracker: RequestTracker, } /// Wrapper for a [`quinn::SendStream`] with additional per request information. 
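The hunks below split the old `StreamContext` into a read-side `ReaderContext` (connection/request ids, start time, bytes read) and a write-side `WriterContext` that also counts written bytes and owns the `RequestTracker`. A condensed sketch of the accept path, taken from the `handle_stream` body later in this patch (the `match` on the request type and error handling elided):

    let request = read_request(&mut reader).await?;
    let tracker = progress
        .get_request(|| GetRequestReceived {
            connection_id: reader.context.connection_id,
            request_id: reader.context.request_id,
            request: request.clone(),
        })
        .await?;
    // the reader context is handed over so bytes_read is not lost
    let mut writer = ProgressWriter::new(writer, reader.context, tracker);
    handle_get(store, request, &mut writer).await?;
    writer.transfer_completed().await; // or writer.transfer_aborted().await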
@@ -182,11 +197,43 @@ pub struct StreamContext { pub struct ProgressWriter { /// The quinn::SendStream to write to pub inner: SendStream, - pub(crate) context: StreamContext, + pub(crate) context: WriterContext, +} + +impl ProgressWriter { + fn new(inner: SendStream, context: ReaderContext, tracker: RequestTracker) -> Self { + Self { inner, context: WriterContext { + connection_id: context.connection_id, + request_id: context.request_id, + bytes_read: context.bytes_read, + t0: context.t0, + payload_bytes_written: 0, + other_bytes_written: 0, + tracker, + } } + } + + async fn transfer_aborted(&self) { + self.tracker.transfer_aborted(|| Some(Box::new(TransferStats { + payload_bytes_sent: self.payload_bytes_written, + other_bytes_sent: self.other_bytes_written, + bytes_read: self.bytes_read, + duration: self.context.t0.elapsed(), + }))).await.ok(); + } + + async fn transfer_completed(&self) { + self.tracker.transfer_completed(|| Box::new(TransferStats { + payload_bytes_sent: self.payload_bytes_written, + other_bytes_sent: self.other_bytes_written, + bytes_read: self.bytes_read, + duration: self.context.t0.elapsed(), + })).await.ok(); + } } impl Deref for ProgressWriter { - type Target = StreamContext; + type Target = WriterContext; fn deref(&self) -> &Self::Target { &self.context @@ -199,140 +246,11 @@ impl DerefMut for ProgressWriter { } } -impl StreamContext { - /// Increase the write count due to a non-payload write. - pub fn log_other_write(&mut self, len: usize) { - self.other_bytes_sent += len as u64; - } - - pub async fn send_transfer_completed(&mut self) { - self.progress - .send(|| Event::TransferCompleted { - connection_id: self.connection_id, - request_id: self.request_id, - stats: Box::new(TransferStats { - payload_bytes_sent: self.payload_bytes_sent, - other_bytes_sent: self.other_bytes_sent, - bytes_read: self.bytes_read, - duration: Duration::ZERO, - }), - }) - .await; - } - - pub async fn send_transfer_aborted(&mut self) { - self.progress - .send(|| Event::TransferAborted { - connection_id: self.connection_id, - request_id: self.request_id, - stats: Some(Box::new(TransferStats { - payload_bytes_sent: self.payload_bytes_sent, - other_bytes_sent: self.other_bytes_sent, - bytes_read: self.bytes_read, - duration: Duration::ZERO, - })), - }) - .await; - } - - /// Increase the write count due to a payload write, and notify the progress sender. - /// - /// `index` is the index of the blob in the request. - /// `offset` is the offset in the blob where the write started. - /// `len` is the length of the write. - pub fn notify_payload_write(&mut self, index: u64, offset: u64, len: usize) { - self.payload_bytes_sent += len as u64; - self.progress.try_send(|| Event::TransferProgress { - connection_id: self.connection_id, - request_id: self.request_id, - index, - end_offset: offset + len as u64, - }); - } - - /// Send a get request received event. - /// - /// This sends all the required information to make sense of subsequent events such as - /// [`Event::TransferStarted`] and [`Event::TransferProgress`]. - pub async fn send_get_request_received(&self, hash: &Hash, ranges: &ChunkRangesSeq) { - self.progress - .send(|| Event::GetRequestReceived { - connection_id: self.connection_id, - request_id: self.request_id, - hash: *hash, - ranges: ranges.clone(), - }) - .await; - } - - /// Send a get request received event. - /// - /// This sends all the required information to make sense of subsequent events such as - /// [`Event::TransferStarted`] and [`Event::TransferProgress`]. 
- pub async fn send_get_many_request_received(&self, hashes: &[Hash], ranges: &ChunkRangesSeq) { - self.progress - .send(|| Event::GetManyRequestReceived { - connection_id: self.connection_id, - request_id: self.request_id, - hashes: hashes.to_vec(), - ranges: ranges.clone(), - }) - .await; - } - - /// Authorize a push request. - /// - /// This will send a request to the event sender, and wait for a response if a - /// progress sender is enabled. If not, it will always fail. - /// - /// We want to make accepting push requests very explicit, since this allows - /// remote nodes to add arbitrary data to our store. - #[must_use = "permit should be checked by the caller"] - pub async fn authorize_push_request(&self, hash: &Hash, ranges: &ChunkRangesSeq) -> bool { - let mut wait_for_permit = None; - // send the request, including the permit channel - self.progress - .send(|| { - let (tx, rx) = oneshot::channel(); - wait_for_permit = Some(rx); - Event::PushRequestReceived { - connection_id: self.connection_id, - request_id: self.request_id, - hash: *hash, - ranges: ranges.clone(), - permitted: tx, - } - }) - .await; - // wait for the permit, if necessary - if let Some(wait_for_permit) = wait_for_permit { - // if somebody does not handle the request, they will drop the channel, - // and this will fail immediately. - wait_for_permit.await.unwrap_or(false) - } else { - false - } - } - - /// Send a transfer started event. - pub async fn send_transfer_started(&self, index: u64, hash: &Hash, size: u64) { - self.progress - .send(|| Event::TransferStarted { - connection_id: self.connection_id, - request_id: self.request_id, - index, - hash: *hash, - size, - }) - .await; - } -} - /// Handle a single connection. pub async fn handle_connection( connection: endpoint::Connection, store: Store, - progress: EventSender, + progress: EventSender2, ) { let connection_id = connection.stable_id() as u64; let span = debug_span!("connection", connection_id); @@ -341,11 +259,14 @@ pub async fn handle_connection( warn!("failed to get node id"); return; }; - if !progress - .authorize_client_connection(connection_id, node_id) + if let Err(cause) =progress + .client_connected(|| ClientConnected { + connection_id, + node_id, + }) .await { - debug!("client not authorized to connect"); + debug!("client not authorized to connect: {cause}"); return; } while let Ok((writer, reader)) = connection.accept_bi().await { @@ -354,35 +275,24 @@ pub async fn handle_connection( let request_id = reader.id().index(); let span = debug_span!("stream", stream_id = %request_id); let store = store.clone(); - let mut writer = ProgressWriter { - inner: writer, - context: StreamContext { - connection_id, - request_id, - payload_bytes_sent: 0, - other_bytes_sent: 0, - bytes_read: 0, - progress: progress.clone(), - }, + let context = ReaderContext { + t0: Instant::now(), + connection_id: connection_id, + request_id: request_id, + bytes_read: 0, + }; + let reader = ProgressReader { + inner: reader, + context, }; tokio::spawn( - async move { - match handle_stream(store, reader, &mut writer).await { - Ok(()) => { - writer.send_transfer_completed().await; - } - Err(err) => { - warn!("error: {err:#?}",); - writer.send_transfer_aborted().await; - } - } - } + handle_stream(store, reader, writer, progress.clone()) .instrument(span), ); } progress - .send(Event::ConnectionClosed { connection_id }) - .await; + .connection_closed(|| ConnectionClosed { connection_id }) + .await.ok(); } .instrument(span) .await @@ -390,56 +300,69 @@ pub async fn 
handle_connection( async fn handle_stream( store: Store, - reader: RecvStream, - writer: &mut ProgressWriter, -) -> Result<()> { + mut reader: ProgressReader, + writer: SendStream, + progress: EventSender2, +) { // 1. Decode the request. debug!("reading request"); - let mut reader = ProgressReader { - inner: reader, - context: StreamContext { - connection_id: writer.connection_id, - request_id: writer.request_id, - payload_bytes_sent: 0, - other_bytes_sent: 0, - bytes_read: 0, - progress: writer.progress.clone(), - }, - }; let request = match read_request(&mut reader).await { Ok(request) => request, Err(e) => { - // todo: increase invalid requests metric counter - return Err(e); + // todo: event for read request failed + return; } }; match request { Request::Get(request) => { + let tracker = match progress.get_request(|| GetRequestReceived { + connection_id: reader.context.connection_id, + request_id: reader.context.request_id, + request: request.clone(), + }).await { + Ok(tracker) => tracker, + Err(e) => { + trace!("Request denied: {}", e); + return; + } + }; // we expect no more bytes after the request, so if there are more bytes, it is an invalid request. - reader.inner.read_to_end(0).await?; - // move the context so we don't lose the bytes read - writer.context = reader.context; - handle_get(store, request, writer).await + let res = reader.inner.read_to_end(0).await; + let mut writer = ProgressWriter::new(writer, reader.context, tracker); + if res.is_err() { + writer.transfer_aborted().await; + return; + } + match handle_get(store, request, &mut writer).await { + Ok(()) => { + writer.transfer_completed().await; + } + Err(_) => { + writer.transfer_aborted().await; + } + } } Request::GetMany(request) => { - // we expect no more bytes after the request, so if there are more bytes, it is an invalid request. - reader.inner.read_to_end(0).await?; - // move the context so we don't lose the bytes read - writer.context = reader.context; - handle_get_many(store, request, writer).await + todo!(); + // // we expect no more bytes after the request, so if there are more bytes, it is an invalid request. + // reader.inner.read_to_end(0).await?; + // // move the context so we don't lose the bytes read + // writer.context = reader.context; + // handle_get_many(store, request, writer).await } Request::Observe(request) => { - // we expect no more bytes after the request, so if there are more bytes, it is an invalid request. - reader.inner.read_to_end(0).await?; - handle_observe(store, request, writer).await + todo!(); + // // we expect no more bytes after the request, so if there are more bytes, it is an invalid request. 
+ // reader.inner.read_to_end(0).await?; + // handle_observe(store, request, writer).await } Request::Push(request) => { - writer.inner.finish()?; - handle_push(store, request, reader).await + todo!(); + // writer.inner.finish()?; + // handle_push(store, request, reader).await } - _ => anyhow::bail!("unsupported request: {request:?}"), - // Request::Push(request) => handle_push(store, request, writer).await, + _ => {}, } } @@ -450,13 +373,9 @@ pub async fn handle_get( store: Store, request: GetRequest, writer: &mut ProgressWriter, -) -> Result<()> { +) -> anyhow::Result<()> { let hash = request.hash; debug!(%hash, "get received request"); - - writer - .send_get_request_received(&hash, &request.ranges) - .await; let mut hash_seq = None; for (offset, ranges) in request.ranges.iter_non_empty_infinite() { if offset == 0 { @@ -496,9 +415,6 @@ pub async fn handle_get_many( writer: &mut ProgressWriter, ) -> Result<()> { debug!("get_many received request"); - writer - .send_get_many_request_received(&request.hashes, &request.ranges) - .await; let request_ranges = request.ranges.iter_infinite(); for (child, (hash, ranges)) in request.hashes.iter().zip(request_ranges).enumerate() { if !ranges.is_empty() { @@ -518,10 +434,6 @@ pub async fn handle_push( ) -> Result<()> { let hash = request.hash; debug!(%hash, "push received request"); - if !reader.authorize_push_request(&hash, &request.ranges).await { - debug!("push request not authorized"); - return Ok(()); - }; let mut request_ranges = request.ranges.iter_infinite(); let root_ranges = request_ranges.next().expect("infinite iterator"); if !root_ranges.is_empty() { @@ -602,7 +514,7 @@ async fn send_observe_item(writer: &mut ProgressWriter, item: &Bitfield) -> Resu use irpc::util::AsyncWriteVarintExt; let item = ObserveItem::from(item); let len = writer.inner.write_length_prefixed(item).await?; - writer.log_other_write(len); + writer.context.log_other_write(len); Ok(()) } @@ -701,11 +613,11 @@ impl EventSender { pub struct ProgressReader { inner: RecvStream, - context: StreamContext, + context: ReaderContext, } impl Deref for ProgressReader { - type Target = StreamContext; + type Target = ReaderContext; fn deref(&self) -> &Self::Target { &self.context diff --git a/src/provider/event_proto.rs b/src/provider/events.rs similarity index 94% rename from src/provider/event_proto.rs rename to src/provider/events.rs index 8713d8aac..55383e77c 100644 --- a/src/provider/event_proto.rs +++ b/src/provider/events.rs @@ -7,7 +7,7 @@ use irpc::{ use serde::{Deserialize, Serialize}; use snafu::Snafu; -use crate::provider::{event_proto::irpc_ext::IrpcClientExt, TransferStats}; +use crate::{provider::{events::irpc_ext::IrpcClientExt, TransferStats}, Hash}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] #[repr(u8)] @@ -136,7 +136,7 @@ impl EventMask { pub struct Notify(T); #[derive(Debug, Default, Clone)] -pub struct Client { +pub struct EventSender { mask: EventMask, inner: Option>, } @@ -150,9 +150,10 @@ enum RequestUpdates { Active(mpsc::Sender), /// Disabled request tracking, we just hold on to the sender so it drops /// once the request is completed or aborted. 
- Disabled(mpsc::Sender), + Disabled(#[allow(dead_code)] mpsc::Sender), } +#[derive(Debug)] pub struct RequestTracker { updates: RequestUpdates, throttle: Option<(irpc::Client, u64, u64)>, @@ -173,9 +174,9 @@ impl RequestTracker { }; /// Transfer for index `index` started, size `size` - pub async fn transfer_started(&self, index: u64, size: u64) -> irpc::Result<()> { + pub async fn transfer_started(&self, index: u64, hash: &Hash, size: u64) -> irpc::Result<()> { if let RequestUpdates::Active(tx) = &self.updates { - tx.send(RequestUpdate::Started(TransferStarted { index, size })) + tx.send(RequestUpdate::Started(TransferStarted { index, hash: *hash, size })) .await?; } Ok(()) @@ -200,7 +201,7 @@ impl RequestTracker { /// Transfer completed for the previously reported blob. pub async fn transfer_completed( - &mut self, + &self, f: impl Fn() -> Box, ) -> irpc::Result<()> { if let RequestUpdates::Active(tx) = &self.updates { @@ -212,7 +213,7 @@ impl RequestTracker { /// Transfer aborted for the previously reported blob. pub async fn transfer_aborted( - &mut self, + &self, f: impl Fn() -> Option>, ) -> irpc::Result<()> { if let RequestUpdates::Active(tx) = &self.updates { @@ -227,13 +228,17 @@ impl RequestTracker { /// /// For most event types, the client can be configured to either send notifications or requests that /// can have a response. -impl Client { +impl EventSender { /// A client that does not send anything. pub const NONE: Self = Self { mask: EventMask::NONE, inner: None, }; + pub fn new(client: tokio::sync::mpsc::Sender, mask: EventMask) -> Self { + Self { mask, inner: Some(irpc::Client::from(client)) } + } + /// A new client has been connected. pub async fn client_connected(&self, f: impl Fn() -> ClientConnected) -> ClientResult { Ok(if let Some(client) = &self.inner { @@ -245,6 +250,13 @@ impl Client { }) } + /// A new client has been connected. + pub async fn connection_closed(&self, f: impl Fn() -> ConnectionClosed) -> ClientResult { + Ok(if let Some(client) = &self.inner { + client.notify(f()).await?; + }) + } + /// Start a get request. You will get back either an error if the request should not proceed, or a /// [`RequestTracker`] that you can use to log progress for this particular request. /// @@ -371,8 +383,7 @@ pub enum ProviderProto { /// A client disconnected from the provider. #[rpc(tx = NoSender)] - #[wrap(ConnectionClosed)] - ConnectionClosed { connection_id: u64 }, + ConnectionClosed(ConnectionClosed), #[rpc(rx = mpsc::Receiver, tx = oneshot::Sender)] /// A new get request was received from the provider. @@ -411,7 +422,7 @@ mod proto { use serde::{Deserialize, Serialize}; use super::Request; - use crate::{protocol::ChunkRangesSeq, provider::TransferStats, Hash}; + use crate::{protocol::{ChunkRangesSeq, GetRequest}, provider::TransferStats, Hash}; #[derive(Debug, Serialize, Deserialize)] pub struct ClientConnected { @@ -419,6 +430,11 @@ mod proto { pub node_id: NodeId, } + #[derive(Debug, Serialize, Deserialize)] + pub struct ConnectionClosed { + pub connection_id: u64, + } + /// A new get request was received from the provider. #[derive(Debug, Serialize, Deserialize)] pub struct GetRequestReceived { @@ -426,10 +442,8 @@ mod proto { pub connection_id: u64, /// The request id. There is a new id for each request. pub request_id: u64, - /// The root hash of the request. - pub hash: Hash, - /// The exact query ranges of the request. 
- pub ranges: ChunkRangesSeq, + /// The request + pub request: GetRequest, } impl Request for GetRequestReceived { @@ -492,6 +506,7 @@ mod proto { #[derive(Debug, Serialize, Deserialize)] pub struct TransferStarted { pub index: u64, + pub hash: Hash, pub size: u64, } @@ -518,7 +533,7 @@ mod proto { Aborted(TransferAborted), } } -use proto::*; +pub use proto::*; mod irpc_ext { use std::future::Future; diff --git a/src/tests.rs b/src/tests.rs index e7dc823e6..9b825bd08 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -16,7 +16,7 @@ use crate::{ hashseq::HashSeq, net_protocol::BlobsProtocol, protocol::{ChunkRangesSeq, GetManyRequest, ObserveRequest, PushRequest}, - provider::Event, + provider::{events::{AbortReason, RequestUpdate}, Event, EventMask, EventSender2, ProviderMessage}, store::{ fs::{ tests::{create_n0_bao, test_data, INTERESTING_SIZES}, @@ -341,32 +341,40 @@ async fn two_nodes_get_many_mem() -> TestResult<()> { fn event_handler( allowed_nodes: impl IntoIterator, ) -> ( - mpsc::Sender, + EventSender2, watch::Receiver, AbortOnDropHandle<()>, ) { let (count_tx, count_rx) = tokio::sync::watch::channel(0usize); - let (events_tx, mut events_rx) = mpsc::channel::(16); + let (events_tx, mut events_rx) = mpsc::channel::(16); let allowed_nodes = allowed_nodes.into_iter().collect::>(); let task = AbortOnDropHandle::new(tokio::task::spawn(async move { while let Some(event) = events_rx.recv().await { match event { - Event::ClientConnected { - node_id, permitted, .. - } => { - permitted.send(allowed_nodes.contains(&node_id)).await.ok(); + ProviderMessage::ClientConnected(msg) => { + let res = if allowed_nodes.contains(&msg.inner.node_id) { + Ok(()) + } else { + Err(AbortReason::Permission) + }; + msg.tx.send(res).await.ok(); } - Event::PushRequestReceived { permitted, .. } => { - permitted.send(true).await.ok(); - } - Event::TransferCompleted { .. 
} => { - count_tx.send_modify(|count| *count += 1); + ProviderMessage::PushRequestReceived(mut msg) => { + msg.tx.send(Ok(())).await.ok(); + let count_tx = count_tx.clone(); + tokio::task::spawn(async move { + while let Ok(Some(update)) = msg.rx.recv().await { + if let RequestUpdate::Completed(_) = update { + count_tx.send_modify(|x| *x += 1); + } + } + }); } _ => {} } } })); - (events_tx, count_rx, task) + (EventSender2::new(events_tx, EventMask::ALL), count_rx, task) } async fn two_nodes_push_blobs( @@ -409,7 +417,7 @@ async fn two_nodes_push_blobs_fs() -> TestResult<()> { let (r1, store1, _) = node_test_setup_fs(testdir.path().join("a")).await?; let (events_tx, count_rx, _task) = event_handler([r1.endpoint().node_id()]); let (r2, store2, _) = - node_test_setup_with_events_fs(testdir.path().join("b"), Some(events_tx)).await?; + node_test_setup_with_events_fs(testdir.path().join("b"), events_tx).await?; two_nodes_push_blobs(r1, &store1, r2, &store2, count_rx).await } @@ -418,7 +426,7 @@ async fn two_nodes_push_blobs_mem() -> TestResult<()> { tracing_subscriber::fmt::try_init().ok(); let (r1, store1) = node_test_setup_mem().await?; let (events_tx, count_rx, _task) = event_handler([r1.endpoint().node_id()]); - let (r2, store2) = node_test_setup_with_events_mem(Some(events_tx)).await?; + let (r2, store2) = node_test_setup_with_events_mem(events_tx).await?; two_nodes_push_blobs(r1, &store1, r2, &store2, count_rx).await } @@ -481,12 +489,12 @@ async fn check_presence(store: &Store, sizes: &[usize]) -> TestResult<()> { } pub async fn node_test_setup_fs(db_path: PathBuf) -> TestResult<(Router, FsStore, PathBuf)> { - node_test_setup_with_events_fs(db_path, None).await + node_test_setup_with_events_fs(db_path, EventSender2::NONE).await } pub async fn node_test_setup_with_events_fs( db_path: PathBuf, - events: Option>, + events: EventSender2, ) -> TestResult<(Router, FsStore, PathBuf)> { let store = crate::store::fs::FsStore::load(&db_path).await?; let ep = Endpoint::builder().bind().await?; @@ -496,11 +504,11 @@ pub async fn node_test_setup_with_events_fs( } pub async fn node_test_setup_mem() -> TestResult<(Router, MemStore)> { - node_test_setup_with_events_mem(None).await + node_test_setup_with_events_mem(EventSender2::NONE).await } pub async fn node_test_setup_with_events_mem( - events: Option>, + events: EventSender2, ) -> TestResult<(Router, MemStore)> { let store = MemStore::new(); let ep = Endpoint::builder().bind().await?; @@ -601,7 +609,7 @@ async fn node_serve_hash_seq() -> TestResult<()> { let root_tt = store.add_bytes(hash_seq).await?; let root = root_tt.hash; let endpoint = Endpoint::builder().discovery_n0().bind().await?; - let blobs = crate::net_protocol::BlobsProtocol::new(&store, endpoint.clone(), None); + let blobs = crate::net_protocol::BlobsProtocol::new(&store, endpoint.clone(), EventSender2::NONE); let r1 = Router::builder(endpoint) .accept(crate::protocol::ALPN, blobs) .spawn(); @@ -632,7 +640,7 @@ async fn node_serve_blobs() -> TestResult<()> { tts.push(store.add_bytes(test_data(size)).await?); } let endpoint = Endpoint::builder().discovery_n0().bind().await?; - let blobs = crate::net_protocol::BlobsProtocol::new(&store, endpoint.clone(), None); + let blobs = crate::net_protocol::BlobsProtocol::new(&store, endpoint.clone(), EventSender2::NONE); let r1 = Router::builder(endpoint) .accept(crate::protocol::ALPN, blobs) .spawn(); @@ -674,7 +682,7 @@ async fn node_smoke(store: &Store) -> TestResult<()> { let tt = store.add_bytes(b"hello world".to_vec()).temp_tag().await?; let 
hash = *tt.hash(); let endpoint = Endpoint::builder().discovery_n0().bind().await?; - let blobs = crate::net_protocol::BlobsProtocol::new(store, endpoint.clone(), None); + let blobs = crate::net_protocol::BlobsProtocol::new(store, endpoint.clone(), EventSender2::NONE); let r1 = Router::builder(endpoint) .accept(crate::protocol::ALPN, blobs) .spawn(); From df1e1ef992bbc1b74167a4cf9b1671ca71ba08d5 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Mon, 1 Sep 2025 11:10:21 +0200 Subject: [PATCH 05/35] tests pass --- examples/random_store.rs | 9 +- src/api/blobs.rs | 7 +- src/net_protocol.rs | 8 +- src/provider.rs | 313 +++++++++++++++++++++++++++------------ src/provider/events.rs | 47 +++--- src/tests.rs | 20 +-- 6 files changed, 261 insertions(+), 143 deletions(-) diff --git a/examples/random_store.rs b/examples/random_store.rs index 6f933d511..5bc136d41 100644 --- a/examples/random_store.rs +++ b/examples/random_store.rs @@ -6,7 +6,7 @@ use iroh::{SecretKey, Watcher}; use iroh_base::ticket::NodeTicket; use iroh_blobs::{ api::downloader::Shuffled, - provider::{AbortReason, Event, EventMask, EventSender2, ProviderMessage}, + provider::{AbortReason, EventMask, EventSender2, ProviderMessage}, store::fs::FsStore, test::{add_hash_sequences, create_random_blobs}, HashAndFormat, @@ -100,12 +100,7 @@ pub fn get_or_generate_secret_key() -> Result { } } -pub fn dump_provider_events( - allow_push: bool, -) -> ( - tokio::task::JoinHandle<()>, - EventSender2, -) { +pub fn dump_provider_events(allow_push: bool) -> (tokio::task::JoinHandle<()>, EventSender2) { let (tx, mut rx) = mpsc::channel(100); let dump_task = tokio::spawn(async move { while let Some(event) = rx.recv().await { diff --git a/src/api/blobs.rs b/src/api/blobs.rs index 76f338359..d00a0a940 100644 --- a/src/api/blobs.rs +++ b/src/api/blobs.rs @@ -57,7 +57,7 @@ use super::{ }; use crate::{ api::proto::{BatchRequest, ImportByteStreamUpdate}, - provider::{ReaderContext, WriterContext}, + provider::WriterContext, store::IROH_BLOCK_SIZE, util::temp_tag::TempTag, BlobFormat, Hash, HashAndFormat, @@ -1180,9 +1180,6 @@ impl WriteProgress for WriterContext { } async fn send_transfer_started(&mut self, index: u64, hash: &Hash, size: u64) { - self.tracker - .transfer_started(index, hash, size) - .await - .ok(); + self.tracker.transfer_started(index, hash, size).await.ok(); } } diff --git a/src/net_protocol.rs b/src/net_protocol.rs index ca64b1a7b..a1d6a1f5d 100644 --- a/src/net_protocol.rs +++ b/src/net_protocol.rs @@ -43,15 +43,9 @@ use iroh::{ protocol::{AcceptError, ProtocolHandler}, Endpoint, Watcher, }; -use tokio::sync::mpsc; use tracing::error; -use crate::{ - api::Store, - provider::{Event, EventSender2}, - ticket::BlobTicket, - HashAndFormat, -}; +use crate::{api::Store, provider::EventSender2, ticket::BlobTicket, HashAndFormat}; #[derive(Debug)] pub(crate) struct BlobsInner { diff --git a/src/provider.rs b/src/provider.rs index b10367911..2f7bd078f 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -18,23 +18,32 @@ use iroh::{ endpoint::{self, RecvStream, SendStream}, NodeId, }; -use irpc::{channel::oneshot}; +use irpc::channel::oneshot; use n0_future::StreamExt; +use quinn::{ClosedStream, ReadToEndError}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use tokio::{io::AsyncRead, select, sync::mpsc}; use tracing::{debug, debug_span, error, trace, warn, Instrument}; use crate::{ - api::{self, blobs::{Bitfield, WriteProgress}, Store}, hashseq::HashSeq, protocol::{ + api::{ + self, + blobs::{Bitfield, WriteProgress}, + Store, + 
}, + hashseq::HashSeq, + protocol::{ ChunkRangesSeq, GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, Request, - }, provider::events::{ClientConnected, ConnectionClosed, GetRequestReceived, RequestTracker}, Hash + }, + provider::events::{ + ClientConnected, ConnectionClosed, GetManyRequestReceived, GetRequestReceived, + PushRequestReceived, RequestTracker, + }, + Hash, }; pub(crate) mod events; -pub use events::EventSender as EventSender2; -pub use events::ProviderMessage; -pub use events::EventMask; -pub use events::AbortReason; +pub use events::{AbortReason, EventMask, EventSender as EventSender2, ProviderMessage}; /// Provider progress events, to keep track of what the provider is doing. /// @@ -155,13 +164,71 @@ pub struct TransferStats { /// leave the rest of the stream for the caller to read. /// /// It is up to the caller do decide if there should be more data. -pub async fn read_request(reader: &mut ProgressReader) -> Result { - let mut counting = CountingReader::new(&mut reader.inner); +pub async fn read_request(context: &mut StreamData) -> Result { + let mut counting = CountingReader::new(&mut context.reader); let res = Request::read_async(&mut counting).await?; - reader.bytes_read += counting.read(); + context.bytes_read += counting.read(); Ok(res) } +#[derive(Debug)] +pub struct StreamData { + pub t0: Instant, + pub connection_id: u64, + pub request_id: u64, + pub reader: RecvStream, + pub writer: SendStream, + pub events: EventSender2, + pub bytes_read: u64, +} + +impl StreamData { + /// We are done with reading. Return a ProgressWriter that contains the read stats and connection id + async fn into_writer( + mut self, + tracker: RequestTracker, + ) -> Result { + let res = self.reader.read_to_end(0).await; + if let Err(e) = res { + tracker.transfer_aborted(|| None).await.ok(); + return Err(e); + }; + Ok(ProgressWriter::new( + self.writer, + WriterContext { + t0: self.t0, + connection_id: self.connection_id, + request_id: self.request_id, + bytes_read: self.bytes_read, + payload_bytes_written: 0, + other_bytes_written: 0, + tracker, + }, + )) + } + + async fn into_reader( + mut self, + tracker: RequestTracker, + ) -> Result { + let res = self.writer.finish(); + if let Err(e) = res { + tracker.transfer_aborted(|| None).await.ok(); + return Err(e); + }; + Ok(ProgressReader::new( + self.reader, + ReaderContext { + t0: self.t0, + connection_id: self.connection_id, + request_id: self.request_id, + bytes_read: self.bytes_read, + tracker, + }, + )) + } +} + #[derive(Debug)] pub struct ReaderContext { /// The start time of the transfer @@ -172,6 +239,20 @@ pub struct ReaderContext { pub request_id: u64, /// The number of bytes read from the stream pub bytes_read: u64, + /// Progress tracking for the request + pub tracker: RequestTracker, +} + +impl ReaderContext { + pub fn new(context: StreamData, tracker: RequestTracker) -> Self { + Self { + t0: context.t0, + connection_id: context.connection_id, + request_id: context.request_id, + bytes_read: context.bytes_read, + tracker, + } + } } #[derive(Debug)] @@ -192,6 +273,20 @@ pub struct WriterContext { pub tracker: RequestTracker, } +impl WriterContext { + pub fn new(context: &StreamData, tracker: RequestTracker) -> Self { + Self { + t0: context.t0, + connection_id: context.connection_id, + request_id: context.request_id, + bytes_read: context.bytes_read, + payload_bytes_written: 0, + other_bytes_written: 0, + tracker, + } + } +} + /// Wrapper for a [`quinn::SendStream`] with additional per request information. 
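/// Payload and non-payload bytes written are tracked separately so that a
/// precise [`TransferStats`] can be reported through the attached
/// [`RequestTracker`] once the transfer completes or is aborted.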
#[derive(Debug)] pub struct ProgressWriter { @@ -201,34 +296,36 @@ pub struct ProgressWriter { } impl ProgressWriter { - fn new(inner: SendStream, context: ReaderContext, tracker: RequestTracker) -> Self { - Self { inner, context: WriterContext { - connection_id: context.connection_id, - request_id: context.request_id, - bytes_read: context.bytes_read, - t0: context.t0, - payload_bytes_written: 0, - other_bytes_written: 0, - tracker, - } } + fn new(inner: SendStream, context: WriterContext) -> Self { + Self { inner, context } } async fn transfer_aborted(&self) { - self.tracker.transfer_aborted(|| Some(Box::new(TransferStats { - payload_bytes_sent: self.payload_bytes_written, - other_bytes_sent: self.other_bytes_written, - bytes_read: self.bytes_read, - duration: self.context.t0.elapsed(), - }))).await.ok(); + self.tracker + .transfer_aborted(|| { + Some(Box::new(TransferStats { + payload_bytes_sent: self.payload_bytes_written, + other_bytes_sent: self.other_bytes_written, + bytes_read: self.bytes_read, + duration: self.context.t0.elapsed(), + })) + }) + .await + .ok(); } async fn transfer_completed(&self) { - self.tracker.transfer_completed(|| Box::new(TransferStats { - payload_bytes_sent: self.payload_bytes_written, - other_bytes_sent: self.other_bytes_written, - bytes_read: self.bytes_read, - duration: self.context.t0.elapsed(), - })).await.ok(); + self.tracker + .transfer_completed(|| { + Box::new(TransferStats { + payload_bytes_sent: self.payload_bytes_written, + other_bytes_sent: self.other_bytes_written, + bytes_read: self.bytes_read, + duration: self.context.t0.elapsed(), + }) + }) + .await + .ok(); } } @@ -259,7 +356,7 @@ pub async fn handle_connection( warn!("failed to get node id"); return; }; - if let Err(cause) =progress + if let Err(cause) = progress .client_connected(|| ClientConnected { connection_id, node_id, @@ -275,95 +372,87 @@ pub async fn handle_connection( let request_id = reader.id().index(); let span = debug_span!("stream", stream_id = %request_id); let store = store.clone(); - let context = ReaderContext { + let context = StreamData { t0: Instant::now(), connection_id: connection_id, request_id: request_id, + reader, + writer, + events: progress.clone(), bytes_read: 0, }; - let reader = ProgressReader { - inner: reader, - context, - }; - tokio::spawn( - handle_stream(store, reader, writer, progress.clone()) - .instrument(span), - ); + tokio::spawn(handle_stream(store, context).instrument(span)); } progress .connection_closed(|| ConnectionClosed { connection_id }) - .await.ok(); + .await + .ok(); } .instrument(span) .await } -async fn handle_stream( - store: Store, - mut reader: ProgressReader, - writer: SendStream, - progress: EventSender2, -) { +async fn handle_stream(store: Store, mut context: StreamData) -> anyhow::Result<()> { // 1. Decode the request. debug!("reading request"); - let request = match read_request(&mut reader).await { - Ok(request) => request, - Err(e) => { - // todo: event for read request failed - return; - } - }; + let request = read_request(&mut context).await?; match request { Request::Get(request) => { - let tracker = match progress.get_request(|| GetRequestReceived { - connection_id: reader.context.connection_id, - request_id: reader.context.request_id, - request: request.clone(), - }).await { - Ok(tracker) => tracker, - Err(e) => { - trace!("Request denied: {}", e); - return; - } - }; - // we expect no more bytes after the request, so if there are more bytes, it is an invalid request. 
- let res = reader.inner.read_to_end(0).await; - let mut writer = ProgressWriter::new(writer, reader.context, tracker); - if res.is_err() { + let tracker = context + .events + .get_request(|| GetRequestReceived { + connection_id: context.connection_id, + request_id: context.request_id, + request: request.clone(), + }) + .await?; + let mut writer = context.into_writer(tracker).await?; + if handle_get(store, request, &mut writer).await.is_ok() { + writer.transfer_completed().await; + } else { writer.transfer_aborted().await; - return; - } - match handle_get(store, request, &mut writer).await { - Ok(()) => { - writer.transfer_completed().await; - } - Err(_) => { - writer.transfer_aborted().await; - } } } Request::GetMany(request) => { - todo!(); - // // we expect no more bytes after the request, so if there are more bytes, it is an invalid request. - // reader.inner.read_to_end(0).await?; - // // move the context so we don't lose the bytes read - // writer.context = reader.context; - // handle_get_many(store, request, writer).await + let tracker = context + .events + .get_many_request(|| GetManyRequestReceived { + connection_id: context.connection_id, + request_id: context.request_id, + request: request.clone(), + }) + .await?; + let mut writer = context.into_writer(tracker).await?; + if handle_get_many(store, request, &mut writer).await.is_ok() { + writer.transfer_completed().await; + } else { + writer.transfer_aborted().await; + } } Request::Observe(request) => { - todo!(); - // // we expect no more bytes after the request, so if there are more bytes, it is an invalid request. - // reader.inner.read_to_end(0).await?; - // handle_observe(store, request, writer).await + let mut writer = context.into_writer(RequestTracker::NONE).await?; + handle_observe(store, request, &mut writer).await.ok(); } Request::Push(request) => { - todo!(); - // writer.inner.finish()?; - // handle_push(store, request, reader).await + let tracker = context + .events + .push_request(|| PushRequestReceived { + connection_id: context.connection_id, + request_id: context.request_id, + request: request.clone(), + }) + .await?; + let mut reader = context.into_reader(tracker).await?; + if handle_push(store, request, &mut reader).await.is_ok() { + reader.transfer_completed().await; + } else { + reader.transfer_aborted().await; + } } - _ => {}, + _ => {} } + Ok(()) } /// Handle a single get request. 
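For orientation, the consuming side of this event protocol at this point in the series looks roughly like the sketch below. It condenses the `event_handler` test and the `random_store` example touched by this patch; the helper name `spawn_event_loop` is illustrative, and the variant list matches the protocol as of this commit.

```rust
use iroh_blobs::provider::{AbortReason, EventMask, EventSender2, ProviderMessage};
use tokio::sync::mpsc;

// Illustrative helper: answer provider events on a background task and
// return the sender that is handed to `BlobsProtocol::new`.
fn spawn_event_loop() -> EventSender2 {
    let (tx, mut rx) = mpsc::channel::<ProviderMessage>(16);
    tokio::spawn(async move {
        while let Some(msg) = rx.recv().await {
            match msg {
                // Request-mode events carry a `tx`; reply `Ok(())` to admit
                // the request or `Err(AbortReason::..)` to reject it.
                ProviderMessage::ClientConnected(msg) => {
                    msg.tx.send(Ok(())).await.ok();
                }
                ProviderMessage::GetRequestReceived(msg) => {
                    msg.tx.send(Ok(())).await.ok();
                }
                ProviderMessage::GetManyRequestReceived(msg) => {
                    msg.tx.send(Ok(())).await.ok();
                }
                // Reject all pushes in this sketch.
                ProviderMessage::PushRequestReceived(msg) => {
                    msg.tx.send(Err(AbortReason::Permission)).await.ok();
                }
                // With throttling enabled, every throttle request must be
                // answered or the transfer is aborted.
                ProviderMessage::Throttle(msg) => {
                    msg.tx.send(Ok(())).await.ok();
                }
                // Remaining variants are notifications and need no reply.
                _ => {}
            }
        }
    });
    EventSender2::new(tx, EventMask::ALL)
}
```

Request-mode events block the provider until the handler replies, which is what enables per-request admission control; notify-mode events are fire-and-forget. The returned sender replaces the old `Option<mpsc::Sender<Event>>` argument of `BlobsProtocol::new`.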
@@ -430,7 +519,7 @@ pub async fn handle_get_many( pub async fn handle_push( store: Store, request: PushRequest, - mut reader: ProgressReader, + reader: &mut ProgressReader, ) -> Result<()> { let hash = request.hash; debug!(%hash, "push received request"); @@ -616,6 +705,40 @@ pub struct ProgressReader { context: ReaderContext, } +impl ProgressReader { + pub fn new(inner: RecvStream, context: ReaderContext) -> Self { + Self { inner, context } + } + + async fn transfer_aborted(&self) { + self.tracker + .transfer_aborted(|| { + Some(Box::new(TransferStats { + payload_bytes_sent: 0, + other_bytes_sent: 0, + bytes_read: self.bytes_read, + duration: self.context.t0.elapsed(), + })) + }) + .await + .ok(); + } + + async fn transfer_completed(&self) { + self.tracker + .transfer_completed(|| { + Box::new(TransferStats { + payload_bytes_sent: 0, + other_bytes_sent: 0, + bytes_read: self.bytes_read, + duration: self.context.t0.elapsed(), + }) + }) + .await + .ok(); + } +} + impl Deref for ProgressReader { type Target = ReaderContext; diff --git a/src/provider/events.rs b/src/provider/events.rs index 55383e77c..2ae4ba5be 100644 --- a/src/provider/events.rs +++ b/src/provider/events.rs @@ -7,7 +7,10 @@ use irpc::{ use serde::{Deserialize, Serialize}; use snafu::Snafu; -use crate::{provider::{events::irpc_ext::IrpcClientExt, TransferStats}, Hash}; +use crate::{ + provider::{events::irpc_ext::IrpcClientExt, TransferStats}, + Hash, +}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] #[repr(u8)] @@ -168,7 +171,7 @@ impl RequestTracker { } /// A request tracker that doesn't track anything. - const NONE: Self = Self { + pub const NONE: Self = Self { updates: RequestUpdates::None, throttle: None, }; @@ -176,8 +179,12 @@ impl RequestTracker { /// Transfer for index `index` started, size `size` pub async fn transfer_started(&self, index: u64, hash: &Hash, size: u64) -> irpc::Result<()> { if let RequestUpdates::Active(tx) = &self.updates { - tx.send(RequestUpdate::Started(TransferStarted { index, hash: *hash, size })) - .await?; + tx.send(RequestUpdate::Started(TransferStarted { + index, + hash: *hash, + size, + })) + .await?; } Ok(()) } @@ -200,10 +207,7 @@ impl RequestTracker { } /// Transfer completed for the previously reported blob. - pub async fn transfer_completed( - &self, - f: impl Fn() -> Box, - ) -> irpc::Result<()> { + pub async fn transfer_completed(&self, f: impl Fn() -> Box) -> irpc::Result<()> { if let RequestUpdates::Active(tx) = &self.updates { tx.send(RequestUpdate::Completed(TransferCompleted { stats: f() })) .await?; @@ -236,7 +240,10 @@ impl EventSender { }; pub fn new(client: tokio::sync::mpsc::Sender, mask: EventMask) -> Self { - Self { mask, inner: Some(irpc::Client::from(client)) } + Self { + mask, + inner: Some(irpc::Client::from(client)), + } } /// A new client has been connected. @@ -422,7 +429,11 @@ mod proto { use serde::{Deserialize, Serialize}; use super::Request; - use crate::{protocol::{ChunkRangesSeq, GetRequest}, provider::TransferStats, Hash}; + use crate::{ + protocol::{GetManyRequest, GetRequest, PushRequest}, + provider::TransferStats, + Hash, + }; #[derive(Debug, Serialize, Deserialize)] pub struct ClientConnected { @@ -458,10 +469,8 @@ mod proto { pub connection_id: u64, /// The request id. There is a new id for each request. pub request_id: u64, - /// The root hash of the request. - pub hashes: Vec, - /// The exact query ranges of the request. 
- pub ranges: ChunkRangesSeq, + /// The request + pub request: GetManyRequest, } impl Request for GetManyRequestReceived { @@ -476,10 +485,8 @@ mod proto { pub connection_id: u64, /// The request id. There is a new id for each request. pub request_id: u64, - /// The root hash of the request. - pub hash: Hash, - /// The exact query ranges of the request. - pub ranges: ChunkRangesSeq, + /// The request + pub request: PushRequest, } impl Request for PushRequestReceived { @@ -539,7 +546,7 @@ mod irpc_ext { use std::future::Future; use irpc::{ - channel::{mpsc, none::NoSender, oneshot}, + channel::{mpsc, none::NoSender}, Channels, RpcMessage, Service, WithChannels, }; @@ -581,7 +588,7 @@ mod irpc_ext { Ok(req_tx) } irpc::Request::Remote(remote) => { - let (s, r) = remote.write(msg).await?; + let (s, _) = remote.write(msg).await?; Ok(s.into()) } } diff --git a/src/tests.rs b/src/tests.rs index 9b825bd08..e99d0fe02 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -16,7 +16,10 @@ use crate::{ hashseq::HashSeq, net_protocol::BlobsProtocol, protocol::{ChunkRangesSeq, GetManyRequest, ObserveRequest, PushRequest}, - provider::{events::{AbortReason, RequestUpdate}, Event, EventMask, EventSender2, ProviderMessage}, + provider::{ + events::{AbortReason, RequestUpdate}, + EventMask, EventSender2, ProviderMessage, + }, store::{ fs::{ tests::{create_n0_bao, test_data, INTERESTING_SIZES}, @@ -340,11 +343,7 @@ async fn two_nodes_get_many_mem() -> TestResult<()> { fn event_handler( allowed_nodes: impl IntoIterator, -) -> ( - EventSender2, - watch::Receiver, - AbortOnDropHandle<()>, -) { +) -> (EventSender2, watch::Receiver, AbortOnDropHandle<()>) { let (count_tx, count_rx) = tokio::sync::watch::channel(0usize); let (events_tx, mut events_rx) = mpsc::channel::(16); let allowed_nodes = allowed_nodes.into_iter().collect::>(); @@ -609,7 +608,8 @@ async fn node_serve_hash_seq() -> TestResult<()> { let root_tt = store.add_bytes(hash_seq).await?; let root = root_tt.hash; let endpoint = Endpoint::builder().discovery_n0().bind().await?; - let blobs = crate::net_protocol::BlobsProtocol::new(&store, endpoint.clone(), EventSender2::NONE); + let blobs = + crate::net_protocol::BlobsProtocol::new(&store, endpoint.clone(), EventSender2::NONE); let r1 = Router::builder(endpoint) .accept(crate::protocol::ALPN, blobs) .spawn(); @@ -640,7 +640,8 @@ async fn node_serve_blobs() -> TestResult<()> { tts.push(store.add_bytes(test_data(size)).await?); } let endpoint = Endpoint::builder().discovery_n0().bind().await?; - let blobs = crate::net_protocol::BlobsProtocol::new(&store, endpoint.clone(), EventSender2::NONE); + let blobs = + crate::net_protocol::BlobsProtocol::new(&store, endpoint.clone(), EventSender2::NONE); let r1 = Router::builder(endpoint) .accept(crate::protocol::ALPN, blobs) .spawn(); @@ -682,7 +683,8 @@ async fn node_smoke(store: &Store) -> TestResult<()> { let tt = store.add_bytes(b"hello world".to_vec()).temp_tag().await?; let hash = *tt.hash(); let endpoint = Endpoint::builder().discovery_n0().bind().await?; - let blobs = crate::net_protocol::BlobsProtocol::new(store, endpoint.clone(), EventSender2::NONE); + let blobs = + crate::net_protocol::BlobsProtocol::new(store, endpoint.clone(), EventSender2::NONE); let r1 = Router::builder(endpoint) .accept(crate::protocol::ALPN, blobs) .spawn(); From a9ac8e57cde7975e1844298703f9a17a7dd897d5 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Mon, 1 Sep 2025 15:31:41 +0200 Subject: [PATCH 06/35] Everything works --- README.md | 4 +- examples/custom-protocol.rs | 6 +- 
examples/mdns-discovery.rs | 4 +- examples/random_store.rs | 48 +- examples/transfer.rs | 6 +- .../store/fs/util/entity_manager.txt | 7 + src/net_protocol.rs | 10 +- src/provider.rs | 425 ++++++------------ src/provider/events.rs | 249 ++++------ src/tests.rs | 24 +- 10 files changed, 295 insertions(+), 488 deletions(-) create mode 100644 proptest-regressions/store/fs/util/entity_manager.txt diff --git a/README.md b/README.md index 2f374e8fb..1a136e44d 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ Here is a basic example of how to set up `iroh-blobs` with `iroh`: ```rust,no_run use iroh::{protocol::Router, Endpoint}; -use iroh_blobs::{store::mem::MemStore, BlobsProtocol}; +use iroh_blobs::{store::mem::MemStore, BlobsProtocol, provider::events::EventSender}; #[tokio::main] async fn main() -> anyhow::Result<()> { @@ -44,7 +44,7 @@ async fn main() -> anyhow::Result<()> { // create a protocol handler using an in-memory blob store. let store = MemStore::new(); - let blobs = BlobsProtocol::new(&store, endpoint.clone(), None); + let blobs = BlobsProtocol::new(&store, endpoint.clone(), EventSender::NONE); // build the router let router = Router::builder(endpoint) diff --git a/examples/custom-protocol.rs b/examples/custom-protocol.rs index 6542acd18..d4d29e27f 100644 --- a/examples/custom-protocol.rs +++ b/examples/custom-protocol.rs @@ -48,7 +48,9 @@ use iroh::{ protocol::{AcceptError, ProtocolHandler, Router}, Endpoint, NodeId, }; -use iroh_blobs::{api::Store, provider::EventSender2, store::mem::MemStore, BlobsProtocol, Hash}; +use iroh_blobs::{ + api::Store, provider::events::EventSender, store::mem::MemStore, BlobsProtocol, Hash, +}; mod common; use common::{get_or_generate_secret_key, setup_logging}; @@ -100,7 +102,7 @@ async fn listen(text: Vec) -> Result<()> { proto.insert_and_index(text).await?; } // Build the iroh-blobs protocol handler, which is used to download blobs. - let blobs = BlobsProtocol::new(&store, endpoint.clone(), EventSender2::NONE); + let blobs = BlobsProtocol::new(&store, endpoint.clone(), EventSender::NONE); // create a router that handles both our custom protocol and the iroh-blobs protocol. 
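    // (Connections are dispatched to the handler registered for their ALPN.)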
let node = Router::builder(endpoint) diff --git a/examples/mdns-discovery.rs b/examples/mdns-discovery.rs index ef5d0619c..ab11bc864 100644 --- a/examples/mdns-discovery.rs +++ b/examples/mdns-discovery.rs @@ -18,7 +18,7 @@ use clap::{Parser, Subcommand}; use iroh::{ discovery::mdns::MdnsDiscovery, protocol::Router, Endpoint, PublicKey, RelayMode, SecretKey, }; -use iroh_blobs::{provider::EventSender2, store::mem::MemStore, BlobsProtocol, Hash}; +use iroh_blobs::{provider::events::EventSender, store::mem::MemStore, BlobsProtocol, Hash}; mod common; use common::{get_or_generate_secret_key, setup_logging}; @@ -68,7 +68,7 @@ async fn accept(path: &Path) -> Result<()> { .await?; let builder = Router::builder(endpoint.clone()); let store = MemStore::new(); - let blobs = BlobsProtocol::new(&store, endpoint.clone(), EventSender2::NONE); + let blobs = BlobsProtocol::new(&store, endpoint.clone(), EventSender::NONE); let builder = builder.accept(iroh_blobs::ALPN, blobs.clone()); let node = builder.spawn(); diff --git a/examples/random_store.rs b/examples/random_store.rs index 5bc136d41..f36017e8d 100644 --- a/examples/random_store.rs +++ b/examples/random_store.rs @@ -6,11 +6,12 @@ use iroh::{SecretKey, Watcher}; use iroh_base::ticket::NodeTicket; use iroh_blobs::{ api::downloader::Shuffled, - provider::{AbortReason, EventMask, EventSender2, ProviderMessage}, + provider::events::{AbortReason, EventMask, EventSender, ProviderMessage}, store::fs::FsStore, test::{add_hash_sequences, create_random_blobs}, HashAndFormat, }; +use irpc::RpcMessage; use n0_future::StreamExt; use rand::{rngs::StdRng, Rng, SeedableRng}; use tokio::{signal::ctrl_c, sync::mpsc}; @@ -100,8 +101,15 @@ pub fn get_or_generate_secret_key() -> Result { } } -pub fn dump_provider_events(allow_push: bool) -> (tokio::task::JoinHandle<()>, EventSender2) { +pub fn dump_provider_events(allow_push: bool) -> (tokio::task::JoinHandle<()>, EventSender) { let (tx, mut rx) = mpsc::channel(100); + fn dump_updates(mut rx: irpc::channel::mpsc::Receiver) { + tokio::spawn(async move { + while let Ok(Some(update)) = rx.recv().await { + println!("{update:?}"); + } + }); + } let dump_task = tokio::spawn(async move { while let Some(event) = rx.recv().await { match event { @@ -115,29 +123,23 @@ pub fn dump_provider_events(allow_push: bool) -> (tokio::task::JoinHandle<()>, E ProviderMessage::ConnectionClosed(msg) => { println!("{:?}", msg.inner); } - ProviderMessage::GetRequestReceived(mut msg) => { + ProviderMessage::GetRequestReceived(msg) => { println!("{:?}", msg.inner); msg.tx.send(Ok(())).await.ok(); - tokio::spawn(async move { - while let Ok(update) = msg.rx.recv().await { - info!("{update:?}"); - } - }); + dump_updates(msg.rx); } ProviderMessage::GetRequestReceivedNotify(msg) => { println!("{:?}", msg.inner); + dump_updates(msg.rx); } - ProviderMessage::GetManyRequestReceived(mut msg) => { + ProviderMessage::GetManyRequestReceived(msg) => { println!("{:?}", msg.inner); msg.tx.send(Ok(())).await.ok(); - tokio::spawn(async move { - while let Ok(update) = msg.rx.recv().await { - info!("{update:?}"); - } - }); + dump_updates(msg.rx); } ProviderMessage::GetManyRequestReceivedNotify(msg) => { println!("{:?}", msg.inner); + dump_updates(msg.rx); } ProviderMessage::PushRequestReceived(msg) => { println!("{:?}", msg.inner); @@ -147,9 +149,25 @@ pub fn dump_provider_events(allow_push: bool) -> (tokio::task::JoinHandle<()>, E Err(AbortReason::Permission) }; msg.tx.send(res).await.ok(); + dump_updates(msg.rx); } ProviderMessage::PushRequestReceivedNotify(msg) 
=> { println!("{:?}", msg.inner); + dump_updates(msg.rx); + } + ProviderMessage::ObserveRequestReceived(msg) => { + println!("{:?}", msg.inner); + let res = if allow_push { + Ok(()) + } else { + Err(AbortReason::Permission) + }; + msg.tx.send(res).await.ok(); + dump_updates(msg.rx); + } + ProviderMessage::ObserveRequestReceivedNotify(msg) => { + println!("{:?}", msg.inner); + dump_updates(msg.rx); } ProviderMessage::Throttle(msg) => { println!("{:?}", msg.inner); @@ -158,7 +176,7 @@ pub fn dump_provider_events(allow_push: bool) -> (tokio::task::JoinHandle<()>, E } } }); - (dump_task, EventSender2::new(tx, EventMask::ALL)) + (dump_task, EventSender::new(tx, EventMask::ALL)) } #[tokio::main] diff --git a/examples/transfer.rs b/examples/transfer.rs index baa1e343c..8347774ca 100644 --- a/examples/transfer.rs +++ b/examples/transfer.rs @@ -1,7 +1,9 @@ use std::path::PathBuf; use iroh::{protocol::Router, Endpoint}; -use iroh_blobs::{provider::EventSender2, store::mem::MemStore, ticket::BlobTicket, BlobsProtocol}; +use iroh_blobs::{ + provider::events::EventSender, store::mem::MemStore, ticket::BlobTicket, BlobsProtocol, +}; #[tokio::main] async fn main() -> anyhow::Result<()> { @@ -12,7 +14,7 @@ async fn main() -> anyhow::Result<()> { // We initialize an in-memory backing store for iroh-blobs let store = MemStore::new(); // Then we initialize a struct that can accept blobs requests over iroh connections - let blobs = BlobsProtocol::new(&store, endpoint.clone(), EventSender2::NONE); + let blobs = BlobsProtocol::new(&store, endpoint.clone(), EventSender::NONE); // Grab all passed in arguments, the first one is the binary itself, so we skip it. let args: Vec = std::env::args().skip(1).collect(); diff --git a/proptest-regressions/store/fs/util/entity_manager.txt b/proptest-regressions/store/fs/util/entity_manager.txt new file mode 100644 index 000000000..94b6aa63c --- /dev/null +++ b/proptest-regressions/store/fs/util/entity_manager.txt @@ -0,0 +1,7 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. 
+cc 0f2ebc49ab2f84e112f08407bb94654fbcb1f19050a4a8a6196383557696438a # shrinks to input = _TestCountersManagerProptestFsArgs { entries: [(15313427648878534792, 264348813928009031854006459208395772047), (1642534478798447378, 15989109311941500072752977306696275871), (8755041673862065815, 172763711808688570294350362332402629716), (4993597758667891804, 114145440157220458287429360639759690928), (15031383154962489250, 63217081714858286463391060323168548783), (17668469631267503333, 11878544422669770587175118199598836678), (10507570291819955314, 126584081645379643144412921692654648228), (3979008599365278329, 283717221942996985486273080647433218905), (8316838360288996639, 334043288511621783152802090833905919408), (15673798930962474157, 77551315511802713260542200115027244708), (12058791254144360414, 56638044274259821850511200885092637649), (8191628769638031337, 314181956273420400069887649110740549194), (6290369460137232066, 255779791286732775990301011955519176773), (11919824746661852269, 319400891587146831511371932480749645441), (12491631698789073154, 271279849791970841069522263758329847554), (53891048909263304, 12061234604041487609497959407391945555), (9486366498650667097, 311383186592430597410801882015456718030), (15696332331789302593, 306911490707714340526403119780178604150), (8699088947997536151, 312272624973367009520183311568498652066), (1144772544750976199, 200591877747619565555594857038887015), (5907208586200645081, 299942008952473970881666769409865744975), (3384528743842518913, 26230956866762934113564101494944411446), (13877357832690956494, 229457597607752760006918374695475345151), (2965687966026226090, 306489188264741716662410004273408761623), (13624286905717143613, 232801392956394366686194314010536008033), (3622356130274722018, 162030840677521022192355139208505458492), (17807768575470996347, 264107246314713159406963697924105744409), (5103434150074147746, 331686166459964582006209321975587627262), (5962771466034321974, 300961804728115777587520888809168362574), (2930645694242691907, 127752709774252686733969795258447263979), (16197574560597474644, 245410120683069493317132088266217906749), (12478835478062365617, 103838791113879912161511798836229961653), (5503595333662805357, 92368472243854403026472376408708548349), (18122734335129614364, 288955542597300001147753560885976966029), (12688080215989274550, 85237436689682348751672119832134138932), (4148468277722853958, 297778117327421209654837771300216669574), (8749445804640085302, 79595866493078234154562014325793780126), (12442730869682574563, 196176786402808588883611974143577417817), (6110644747049355904, 26592587989877021920275416199052685135), (5851164380497779369, 158876888501825038083692899057819261957), (9497384378514985275, 15279835675313542048650599472403150097), (10661092311826161857, 250089949043892591422587928179995867509), (10046856000675345423, 231369150063141386398059701278066296663)] } diff --git a/src/net_protocol.rs b/src/net_protocol.rs index a1d6a1f5d..aa45aa473 100644 --- a/src/net_protocol.rs +++ b/src/net_protocol.rs @@ -7,7 +7,7 @@ //! ```rust //! # async fn example() -> anyhow::Result<()> { //! use iroh::{protocol::Router, Endpoint}; -//! use iroh_blobs::{store, BlobsProtocol}; +//! use iroh_blobs::{provider::events::EventSender, store, BlobsProtocol}; //! //! // create a store //! let store = store::fs::FsStore::load("blobs").await?; @@ -19,7 +19,7 @@ //! let endpoint = Endpoint::builder().discovery_n0().bind().await?; //! //! // create a blobs protocol handler -//! 
let blobs = BlobsProtocol::new(&store, endpoint.clone(), None); +//! let blobs = BlobsProtocol::new(&store, endpoint.clone(), EventSender::NONE); //! //! // create a router and add the blobs protocol handler //! let router = Router::builder(endpoint) @@ -45,13 +45,13 @@ use iroh::{ }; use tracing::error; -use crate::{api::Store, provider::EventSender2, ticket::BlobTicket, HashAndFormat}; +use crate::{api::Store, provider::events::EventSender, ticket::BlobTicket, HashAndFormat}; #[derive(Debug)] pub(crate) struct BlobsInner { pub(crate) store: Store, pub(crate) endpoint: Endpoint, - pub(crate) events: EventSender2, + pub(crate) events: EventSender, } /// A protocol handler for the blobs protocol. @@ -69,7 +69,7 @@ impl Deref for BlobsProtocol { } impl BlobsProtocol { - pub fn new(store: &Store, endpoint: Endpoint, events: EventSender2) -> Self { + pub fn new(store: &Store, endpoint: Endpoint, events: EventSender) -> Self { Self { inner: Arc::new(BlobsInner { store: store.clone(), diff --git a/src/provider.rs b/src/provider.rs index 2f7bd078f..a98c1b3a7 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -14,16 +14,12 @@ use std::{ use anyhow::{Context, Result}; use bao_tree::ChunkRanges; -use iroh::{ - endpoint::{self, RecvStream, SendStream}, - NodeId, -}; -use irpc::channel::oneshot; +use iroh::endpoint::{self, RecvStream, SendStream}; use n0_future::StreamExt; -use quinn::{ClosedStream, ReadToEndError}; +use quinn::{ClosedStream, ConnectionError, ReadToEndError}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use tokio::{io::AsyncRead, select, sync::mpsc}; -use tracing::{debug, debug_span, error, trace, warn, Instrument}; +use tokio::{io::AsyncRead, select}; +use tracing::{debug, debug_span, warn, Instrument}; use crate::{ api::{ @@ -32,112 +28,12 @@ use crate::{ Store, }, hashseq::HashSeq, - protocol::{ - ChunkRangesSeq, GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, - Request, - }, - provider::events::{ - ClientConnected, ConnectionClosed, GetManyRequestReceived, GetRequestReceived, - PushRequestReceived, RequestTracker, - }, + protocol::{GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, Request}, + provider::events::{ClientConnected, ClientError, ConnectionClosed, RequestTracker}, Hash, }; -pub(crate) mod events; -pub use events::{AbortReason, EventMask, EventSender as EventSender2, ProviderMessage}; - -/// Provider progress events, to keep track of what the provider is doing. -/// -/// ClientConnected -> -/// (GetRequestReceived -> (TransferStarted -> TransferProgress*n)*n -> (TransferCompleted | TransferAborted))*n -> -/// ConnectionClosed -#[derive(Debug)] -pub enum Event { - /// A new client connected to the provider. - ClientConnected { - connection_id: u64, - node_id: NodeId, - permitted: oneshot::Sender, - }, - /// Connection closed. - ConnectionClosed { connection_id: u64 }, - /// A new get request was received from the provider. - GetRequestReceived { - /// The connection id. Multiple requests can be sent over the same connection. - connection_id: u64, - /// The request id. There is a new id for each request. - request_id: u64, - /// The root hash of the request. - hash: Hash, - /// The exact query ranges of the request. - ranges: ChunkRangesSeq, - }, - /// A new get request was received from the provider. - GetManyRequestReceived { - /// The connection id. Multiple requests can be sent over the same connection. - connection_id: u64, - /// The request id. There is a new id for each request. 
- request_id: u64, - /// The root hash of the request. - hashes: Vec, - /// The exact query ranges of the request. - ranges: ChunkRangesSeq, - }, - /// A new get request was received from the provider. - PushRequestReceived { - /// The connection id. Multiple requests can be sent over the same connection. - connection_id: u64, - /// The request id. There is a new id for each request. - request_id: u64, - /// The root hash of the request. - hash: Hash, - /// The exact query ranges of the request. - ranges: ChunkRangesSeq, - /// Complete this to permit the request. - permitted: oneshot::Sender, - }, - /// Transfer for the nth blob started. - TransferStarted { - /// The connection id. Multiple requests can be sent over the same connection. - connection_id: u64, - /// The request id. There is a new id for each request. - request_id: u64, - /// The index of the blob in the request. 0 for the first blob or for raw blob requests. - index: u64, - /// The hash of the blob. This is the hash of the request for the first blob, the child hash (index-1) for subsequent blobs. - hash: Hash, - /// The size of the blob. This is the full size of the blob, not the size we are sending. - size: u64, - }, - /// Progress of the transfer. - TransferProgress { - /// The connection id. Multiple requests can be sent over the same connection. - connection_id: u64, - /// The request id. There is a new id for each request. - request_id: u64, - /// The index of the blob in the request. 0 for the first blob or for raw blob requests. - index: u64, - /// The end offset of the chunk that was sent. - end_offset: u64, - }, - /// Entire transfer completed. - TransferCompleted { - /// The connection id. Multiple requests can be sent over the same connection. - connection_id: u64, - /// The request id. There is a new id for each request. - request_id: u64, - /// Statistics about the transfer. - stats: Box, - }, - /// Entire transfer aborted - TransferAborted { - /// The connection id. Multiple requests can be sent over the same connection. - connection_id: u64, - /// The request id. There is a new id for each request. - request_id: u64, - /// Statistics about the part of the transfer that was aborted. - stats: Option>, - }, -} +pub mod events; +use events::EventSender; /// Statistics about a successful or failed transfer. #[derive(Debug, Serialize, Deserialize)] @@ -150,8 +46,9 @@ pub struct TransferStats { pub other_bytes_sent: u64, /// The number of bytes read from the stream. /// - /// This is the size of the request. - pub bytes_read: u64, + /// In most cases this is just the request, for push requests this is + /// request, size header and hash pairs. + pub other_bytes_read: u64, /// Total duration from reading the request to transfer completed. 
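    /// The clock starts when the bi-directional stream is accepted, so this
    /// also covers reading the request and any permission round-trip with
    /// the event handler.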
pub duration: Duration, } @@ -167,22 +64,38 @@ pub struct TransferStats { pub async fn read_request(context: &mut StreamData) -> Result { let mut counting = CountingReader::new(&mut context.reader); let res = Request::read_async(&mut counting).await?; - context.bytes_read += counting.read(); + context.other_bytes_read += counting.read(); Ok(res) } #[derive(Debug)] pub struct StreamData { - pub t0: Instant, - pub connection_id: u64, - pub request_id: u64, - pub reader: RecvStream, - pub writer: SendStream, - pub events: EventSender2, - pub bytes_read: u64, + t0: Instant, + connection_id: u64, + request_id: u64, + reader: RecvStream, + writer: SendStream, + other_bytes_read: u64, + events: EventSender, } impl StreamData { + pub async fn accept( + conn: &endpoint::Connection, + events: &EventSender, + ) -> Result { + let (writer, reader) = conn.accept_bi().await?; + Ok(Self { + t0: Instant::now(), + connection_id: conn.stable_id() as u64, + request_id: reader.id().into(), + reader, + writer, + other_bytes_read: 0, + events: events.clone(), + }) + } + /// We are done with reading. Return a ProgressWriter that contains the read stats and connection id async fn into_writer( mut self, @@ -190,7 +103,10 @@ impl StreamData { ) -> Result { let res = self.reader.read_to_end(0).await; if let Err(e) = res { - tracker.transfer_aborted(|| None).await.ok(); + tracker + .transfer_aborted(|| Box::new(self.stats())) + .await + .ok(); return Err(e); }; Ok(ProgressWriter::new( @@ -199,7 +115,7 @@ impl StreamData { t0: self.t0, connection_id: self.connection_id, request_id: self.request_id, - bytes_read: self.bytes_read, + other_bytes_read: self.other_bytes_read, payload_bytes_written: 0, other_bytes_written: 0, tracker, @@ -213,7 +129,10 @@ impl StreamData { ) -> Result { let res = self.writer.finish(); if let Err(e) = res { - tracker.transfer_aborted(|| None).await.ok(); + tracker + .transfer_aborted(|| Box::new(self.stats())) + .await + .ok(); return Err(e); }; Ok(ProgressReader::new( @@ -222,11 +141,56 @@ impl StreamData { t0: self.t0, connection_id: self.connection_id, request_id: self.request_id, - bytes_read: self.bytes_read, + other_bytes_read: self.other_bytes_read, tracker, }, )) } + + async fn get_request( + &self, + f: impl FnOnce() -> GetRequest, + ) -> Result { + self.events + .request(f, self.connection_id, self.request_id) + .await + } + + async fn get_many_request( + &self, + f: impl FnOnce() -> GetManyRequest, + ) -> Result { + self.events + .request(f, self.connection_id, self.request_id) + .await + } + + async fn push_request( + &self, + f: impl FnOnce() -> PushRequest, + ) -> Result { + self.events + .request(f, self.connection_id, self.request_id) + .await + } + + async fn observe_request( + &self, + f: impl FnOnce() -> ObserveRequest, + ) -> Result { + self.events + .request(f, self.connection_id, self.request_id) + .await + } + + fn stats(&self) -> TransferStats { + TransferStats { + payload_bytes_sent: 0, + other_bytes_sent: 0, + other_bytes_read: self.other_bytes_read, + duration: self.t0.elapsed(), + } + } } #[derive(Debug)] @@ -238,7 +202,7 @@ pub struct ReaderContext { /// The request ID from the recv stream pub request_id: u64, /// The number of bytes read from the stream - pub bytes_read: u64, + pub other_bytes_read: u64, /// Progress tracking for the request pub tracker: RequestTracker, } @@ -249,10 +213,19 @@ impl ReaderContext { t0: context.t0, connection_id: context.connection_id, request_id: context.request_id, - bytes_read: context.bytes_read, + other_bytes_read: 
context.other_bytes_read, tracker, } } + + pub fn stats(&self) -> TransferStats { + TransferStats { + payload_bytes_sent: 0, + other_bytes_sent: 0, + other_bytes_read: self.other_bytes_read, + duration: self.t0.elapsed(), + } + } } #[derive(Debug)] @@ -264,7 +237,7 @@ pub struct WriterContext { /// The request ID from the recv stream pub request_id: u64, /// The number of bytes read from the stream - pub bytes_read: u64, + pub other_bytes_read: u64, /// The number of payload bytes written to the stream pub payload_bytes_written: u64, /// The number of bytes written that are not part of the payload @@ -279,12 +252,21 @@ impl WriterContext { t0: context.t0, connection_id: context.connection_id, request_id: context.request_id, - bytes_read: context.bytes_read, + other_bytes_read: context.other_bytes_read, payload_bytes_written: 0, other_bytes_written: 0, tracker, } } + + pub fn stats(&self) -> TransferStats { + TransferStats { + payload_bytes_sent: self.payload_bytes_written, + other_bytes_sent: self.other_bytes_written, + other_bytes_read: self.other_bytes_read, + duration: self.t0.elapsed(), + } + } } /// Wrapper for a [`quinn::SendStream`] with additional per request information. @@ -302,28 +284,14 @@ impl ProgressWriter { async fn transfer_aborted(&self) { self.tracker - .transfer_aborted(|| { - Some(Box::new(TransferStats { - payload_bytes_sent: self.payload_bytes_written, - other_bytes_sent: self.other_bytes_written, - bytes_read: self.bytes_read, - duration: self.context.t0.elapsed(), - })) - }) + .transfer_aborted(|| Box::new(self.stats())) .await .ok(); } async fn transfer_completed(&self) { self.tracker - .transfer_completed(|| { - Box::new(TransferStats { - payload_bytes_sent: self.payload_bytes_written, - other_bytes_sent: self.other_bytes_written, - bytes_read: self.bytes_read, - duration: self.context.t0.elapsed(), - }) - }) + .transfer_completed(|| Box::new(self.stats())) .await .ok(); } @@ -347,7 +315,7 @@ impl DerefMut for ProgressWriter { pub async fn handle_connection( connection: endpoint::Connection, store: Store, - progress: EventSender2, + progress: EventSender, ) { let connection_id = connection.stable_id() as u64; let span = debug_span!("connection", connection_id); @@ -366,21 +334,9 @@ pub async fn handle_connection( debug!("client not authorized to connect: {cause}"); return; } - while let Ok((writer, reader)) = connection.accept_bi().await { - // The stream ID index is used to identify this request. Requests only arrive in - // bi-directional RecvStreams initiated by the client, so this uniquely identifies them. 
- let request_id = reader.id().index(); - let span = debug_span!("stream", stream_id = %request_id); + while let Ok(context) = StreamData::accept(&connection, &progress).await { + let span = debug_span!("stream", stream_id = %context.request_id); let store = store.clone(); - let context = StreamData { - t0: Instant::now(), - connection_id: connection_id, - request_id: request_id, - reader, - writer, - events: progress.clone(), - bytes_read: 0, - }; tokio::spawn(handle_stream(store, context).instrument(span)); } progress @@ -399,14 +355,7 @@ async fn handle_stream(store: Store, mut context: StreamData) -> anyhow::Result< match request { Request::Get(request) => { - let tracker = context - .events - .get_request(|| GetRequestReceived { - connection_id: context.connection_id, - request_id: context.request_id, - request: request.clone(), - }) - .await?; + let tracker = context.get_request(|| request.clone()).await?; let mut writer = context.into_writer(tracker).await?; if handle_get(store, request, &mut writer).await.is_ok() { writer.transfer_completed().await; @@ -415,14 +364,7 @@ async fn handle_stream(store: Store, mut context: StreamData) -> anyhow::Result< } } Request::GetMany(request) => { - let tracker = context - .events - .get_many_request(|| GetManyRequestReceived { - connection_id: context.connection_id, - request_id: context.request_id, - request: request.clone(), - }) - .await?; + let tracker = context.get_many_request(|| request.clone()).await?; let mut writer = context.into_writer(tracker).await?; if handle_get_many(store, request, &mut writer).await.is_ok() { writer.transfer_completed().await; @@ -431,18 +373,16 @@ async fn handle_stream(store: Store, mut context: StreamData) -> anyhow::Result< } } Request::Observe(request) => { - let mut writer = context.into_writer(RequestTracker::NONE).await?; - handle_observe(store, request, &mut writer).await.ok(); + let tracker = context.observe_request(|| request.clone()).await?; + let mut writer = context.into_writer(tracker).await?; + if handle_observe(store, request, &mut writer).await.is_ok() { + writer.transfer_completed().await; + } else { + writer.transfer_aborted().await; + } } Request::Push(request) => { - let tracker = context - .events - .push_request(|| PushRequestReceived { - connection_id: context.connection_id, - request_id: context.request_id, - request: request.clone(), - }) - .await?; + let tracker = context.push_request(|| request.clone()).await?; let mut reader = context.into_reader(tracker).await?; if handle_push(store, request, &mut reader).await.is_ok() { reader.transfer_completed().await; @@ -607,99 +547,6 @@ async fn send_observe_item(writer: &mut ProgressWriter, item: &Bitfield) -> Resu Ok(()) } -/// Helper to lazyly create an [`Event`], in the case that the event creation -/// is expensive and we want to avoid it if the progress sender is disabled. -pub trait LazyEvent { - fn call(self) -> Event; -} - -impl LazyEvent for T -where - T: FnOnce() -> Event, -{ - fn call(self) -> Event { - self() - } -} - -impl LazyEvent for Event { - fn call(self) -> Event { - self - } -} - -/// A sender for provider events. 
-#[derive(Debug, Clone)] -pub struct EventSender(EventSenderInner); - -#[derive(Debug, Clone)] -enum EventSenderInner { - Disabled, - Enabled(mpsc::Sender), -} - -impl EventSender { - pub fn new(sender: Option>) -> Self { - match sender { - Some(sender) => Self(EventSenderInner::Enabled(sender)), - None => Self(EventSenderInner::Disabled), - } - } - - /// Send a client connected event, if the progress sender is enabled. - /// - /// This will permit the client to connect if the sender is disabled. - #[must_use = "permit should be checked by the caller"] - pub async fn authorize_client_connection(&self, connection_id: u64, node_id: NodeId) -> bool { - let mut wait_for_permit = None; - self.send(|| { - let (tx, rx) = oneshot::channel(); - wait_for_permit = Some(rx); - Event::ClientConnected { - connection_id, - node_id, - permitted: tx, - } - }) - .await; - if let Some(wait_for_permit) = wait_for_permit { - // if we have events configured, and they drop the channel, we consider that as a no! - // todo: this will be confusing and needs to be properly documented. - wait_for_permit.await.unwrap_or(false) - } else { - true - } - } - - /// Send an ephemeral event, if the progress sender is enabled. - /// - /// The event will only be created if the sender is enabled. - fn try_send(&self, event: impl LazyEvent) { - match &self.0 { - EventSenderInner::Enabled(sender) => { - let value = event.call(); - sender.try_send(value).ok(); - } - EventSenderInner::Disabled => {} - } - } - - /// Send a mandatory event, if the progress sender is enabled. - /// - /// The event only be created if the sender is enabled. - async fn send(&self, event: impl LazyEvent) { - match &self.0 { - EventSenderInner::Enabled(sender) => { - let value = event.call(); - if let Err(err) = sender.send(value).await { - error!("failed to send progress event: {:?}", err); - } - } - EventSenderInner::Disabled => {} - } - } -} - pub struct ProgressReader { inner: RecvStream, context: ReaderContext, @@ -712,28 +559,14 @@ impl ProgressReader { async fn transfer_aborted(&self) { self.tracker - .transfer_aborted(|| { - Some(Box::new(TransferStats { - payload_bytes_sent: 0, - other_bytes_sent: 0, - bytes_read: self.bytes_read, - duration: self.context.t0.elapsed(), - })) - }) + .transfer_aborted(|| Box::new(self.stats())) .await .ok(); } async fn transfer_completed(&self) { self.tracker - .transfer_completed(|| { - Box::new(TransferStats { - payload_bytes_sent: 0, - other_bytes_sent: 0, - bytes_read: self.bytes_read, - duration: self.context.t0.elapsed(), - }) - }) + .transfer_completed(|| Box::new(self.stats())) .await .ok(); } diff --git a/src/provider/events.rs b/src/provider/events.rs index 2ae4ba5be..c8b94c8b8 100644 --- a/src/provider/events.rs +++ b/src/provider/events.rs @@ -8,6 +8,7 @@ use serde::{Deserialize, Serialize}; use snafu::Snafu; use crate::{ + protocol::{GetManyRequest, GetRequest, ObserveRequest, PushRequest}, provider::{events::irpc_ext::IrpcClientExt, TransferStats}, Hash, }; @@ -24,15 +25,27 @@ pub enum ConnectMode { Request, } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +#[repr(u8)] +pub enum ObserveMode { + /// We don't get notification of connect events at all. + #[default] + None, + /// We get a notification for connect events. + Notify, + /// We get a request for connect events and can reject incoming connections. + Request, +} + #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] #[repr(u8)] pub enum RequestMode { /// We don't get request events at all. 
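    /// The provider then behaves as if no event channel were configured for
    /// this request type.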
#[default] None, - /// We get a notification for each request. + /// We get a notification for each request, but no transfer events. Notify, - /// We get a request for each request, and can reject incoming requests. + /// We get a request for each request, and can reject incoming requests, but no transfer events. Request, /// We get a notification for each request as well as detailed transfer events. NotifyLog, @@ -101,6 +114,7 @@ pub struct EventMask { get: RequestMode, get_many: RequestMode, push: RequestMode, + observe: ObserveMode, /// throttling is somewhat costly, so you can disable it completely throttle: ThrottleMode, } @@ -113,6 +127,7 @@ impl EventMask { get_many: RequestMode::None, push: RequestMode::None, throttle: ThrottleMode::None, + observe: ObserveMode::None, }; /// You get asked for every single thing that is going on and can intervene/throttle. @@ -122,6 +137,7 @@ impl EventMask { get_many: RequestMode::RequestLog, push: RequestMode::RequestLog, throttle: ThrottleMode::Throttle, + observe: ObserveMode::Request, }; /// You get notified for every single thing that is going on, but can't intervene. @@ -131,6 +147,7 @@ impl EventMask { get_many: RequestMode::NotifyLog, push: RequestMode::NotifyLog, throttle: ThrottleMode::None, + observe: ObserveMode::Notify, }; } @@ -216,10 +233,7 @@ impl RequestTracker { } /// Transfer aborted for the previously reported blob. - pub async fn transfer_aborted( - &self, - f: impl Fn() -> Option>, - ) -> irpc::Result<()> { + pub async fn transfer_aborted(&self, f: impl Fn() -> Box) -> irpc::Result<()> { if let RequestUpdates::Active(tx) = &self.updates { tx.send(RequestUpdate::Aborted(TransferAborted { stats: f() })) .await?; @@ -264,108 +278,82 @@ impl EventSender { }) } - /// Start a get request. You will get back either an error if the request should not proceed, or a - /// [`RequestTracker`] that you can use to log progress for this particular request. - /// - /// Depending on the event sender config, the returned tracker might be a no-op. - pub async fn get_request( - &self, - f: impl FnOnce() -> GetRequestReceived, - ) -> Result { - self.request(f).await - } - - // Start a get_many request. You will get back either an error if the request should not proceed, or a - /// [`RequestTracker`] that you can use to log progress for this particular request. - /// - /// Depending on the event sender config, the returned tracker might be a no-op. - pub async fn get_many_request( - &self, - f: impl FnOnce() -> GetManyRequestReceived, - ) -> Result { - self.request(f).await - } - - // Start a push request. You will get back either an error if the request should not proceed, or a - /// [`RequestTracker`] that you can use to log progress for this particular request. - /// - /// Depending on the event sender config, the returned tracker might be a no-op. - pub async fn push_request( - &self, - f: impl FnOnce() -> PushRequestReceived, - ) -> Result { - self.request(f).await - } - /// Abstract request, to DRY the 3 to 4 request types. /// /// DRYing stuff with lots of bounds is no fun at all... 
- async fn request(&self, f: impl FnOnce() -> Req) -> Result + pub(crate) async fn request( + &self, + f: impl FnOnce() -> Req, + connection_id: u64, + request_id: u64, + ) -> Result where - Req: Request, - ProviderProto: From, - ProviderMessage: From>, - Req: Channels< + ProviderProto: From>, + ProviderMessage: From, ProviderProto>>, + RequestReceived: Channels< ProviderProto, Tx = oneshot::Sender, Rx = mpsc::Receiver, >, - ProviderProto: From>, - ProviderMessage: From, ProviderProto>>, - Notify: Channels>, + ProviderProto: From>>, + ProviderMessage: From>, ProviderProto>>, + Notify>: + Channels>, { - Ok(self.into_tracker(if let Some(client) = &self.inner { - match self.mask.get { - RequestMode::None => { - if self.mask.throttle == ThrottleMode::Throttle { - // if throttling is enabled, we need to call f to get connection_id and request_id - let msg = f(); - (RequestUpdates::None, msg.id()) - } else { - (RequestUpdates::None, (0, 0)) + Ok(self.into_tracker(( + if let Some(client) = &self.inner { + match self.mask.get { + RequestMode::None => RequestUpdates::None, + RequestMode::Notify => { + let msg = RequestReceived { + request: f(), + connection_id, + request_id, + }; + RequestUpdates::Disabled(client.notify_streaming(Notify(msg), 32).await?) + } + RequestMode::Request => { + let msg = RequestReceived { + request: f(), + connection_id, + request_id, + }; + let (tx, rx) = client.client_streaming(msg, 32).await?; + // bail out if the request is not allowed + rx.await??; + RequestUpdates::Disabled(tx) + } + RequestMode::NotifyLog => { + let msg = RequestReceived { + request: f(), + connection_id, + request_id, + }; + RequestUpdates::Active(client.notify_streaming(Notify(msg), 32).await?) + } + RequestMode::RequestLog => { + let msg = RequestReceived { + request: f(), + connection_id, + request_id, + }; + let (tx, rx) = client.client_streaming(msg, 32).await?; + // bail out if the request is not allowed + rx.await??; + RequestUpdates::Active(tx) } } - RequestMode::Notify => { - let msg = f(); - let id = msg.id(); - ( - RequestUpdates::Disabled(client.notify_streaming(Notify(msg), 32).await?), - id, - ) - } - RequestMode::Request => { - let msg = f(); - let id = msg.id(); - let (tx, rx) = client.client_streaming(msg, 32).await?; - // bail out if the request is not allowed - rx.await??; - (RequestUpdates::Disabled(tx), id) - } - RequestMode::NotifyLog => { - let msg = f(); - let id = msg.id(); - ( - RequestUpdates::Active(client.notify_streaming(Notify(msg), 32).await?), - id, - ) - } - RequestMode::RequestLog => { - let msg = f(); - let id = msg.id(); - let (tx, rx) = client.client_streaming(msg, 32).await?; - // bail out if the request is not allowed - rx.await??; - (RequestUpdates::Active(tx), id) - } - } - } else { - (RequestUpdates::None, (0, 0)) - })) + } else { + RequestUpdates::None + }, + connection_id, + request_id, + ))) } fn into_tracker( &self, - (updates, (connection_id, request_id)): (RequestUpdates, (u64, u64)), + (updates, connection_id, request_id): (RequestUpdates, u64, u64), ) -> RequestTracker { let throttle = match self.mask.throttle { ThrottleMode::None => None, @@ -394,46 +382,45 @@ pub enum ProviderProto { #[rpc(rx = mpsc::Receiver, tx = oneshot::Sender)] /// A new get request was received from the provider. - GetRequestReceived(GetRequestReceived), + GetRequestReceived(RequestReceived), #[rpc(rx = mpsc::Receiver, tx = NoSender)] /// A new get request was received from the provider. 
-    GetRequestReceivedNotify(Notify),
+    GetRequestReceivedNotify(Notify>),
 
     /// A new get request was received from the provider.
     #[rpc(rx = mpsc::Receiver, tx = oneshot::Sender)]
-    GetManyRequestReceived(GetManyRequestReceived),
+    GetManyRequestReceived(RequestReceived),
 
     /// A new get request was received from the provider.
     #[rpc(rx = mpsc::Receiver, tx = NoSender)]
-    GetManyRequestReceivedNotify(Notify),
+    GetManyRequestReceivedNotify(Notify>),
 
     /// A new get request was received from the provider.
     #[rpc(rx = mpsc::Receiver, tx = oneshot::Sender)]
-    PushRequestReceived(PushRequestReceived),
+    PushRequestReceived(RequestReceived),
 
     /// A new get request was received from the provider.
     #[rpc(rx = mpsc::Receiver, tx = NoSender)]
-    PushRequestReceivedNotify(Notify),
+    PushRequestReceivedNotify(Notify>),
+
+    /// A new observe request was received from the provider.
+    #[rpc(rx = mpsc::Receiver, tx = oneshot::Sender)]
+    ObserveRequestReceived(RequestReceived),
+
+    /// A new observe request was received from the provider. Notify variant.
+    #[rpc(rx = mpsc::Receiver, tx = NoSender)]
+    ObserveRequestReceivedNotify(Notify>),
 
     #[rpc(tx = oneshot::Sender)]
     Throttle(Throttle),
 }
 
-trait Request {
-    fn id(&self) -> (u64, u64);
-}
-
 mod proto {
     use iroh::NodeId;
     use serde::{Deserialize, Serialize};
 
-    use super::Request;
-    use crate::{
-        protocol::{GetManyRequest, GetRequest, PushRequest},
-        provider::TransferStats,
-        Hash,
-    };
+    use crate::{provider::TransferStats, Hash};
 
     #[derive(Debug, Serialize, Deserialize)]
     pub struct ClientConnected {
@@ -448,51 +435,13 @@ mod proto {
 
     /// A new get request was received from the provider.
     #[derive(Debug, Serialize, Deserialize)]
-    pub struct GetRequestReceived {
-        /// The connection id. Multiple requests can be sent over the same connection.
-        pub connection_id: u64,
-        /// The request id. There is a new id for each request.
-        pub request_id: u64,
-        /// The request
-        pub request: GetRequest,
-    }
-
-    impl Request for GetRequestReceived {
-        fn id(&self) -> (u64, u64) {
-            (self.connection_id, self.request_id)
-        }
-    }
-
-    #[derive(Debug, Serialize, Deserialize)]
-    pub struct GetManyRequestReceived {
+    pub struct RequestReceived {
         /// The connection id. Multiple requests can be sent over the same connection.
         pub connection_id: u64,
         /// The request id. There is a new id for each request.
         pub request_id: u64,
         /// The request
-        pub request: GetManyRequest,
-    }
-
-    impl Request for GetManyRequestReceived {
-        fn id(&self) -> (u64, u64) {
-            (self.connection_id, self.request_id)
-        }
-    }
-
-    #[derive(Debug, Serialize, Deserialize)]
-    pub struct PushRequestReceived {
-        /// The connection id. Multiple requests can be sent over the same connection.
-        pub connection_id: u64,
-        /// The request id. There is a new id for each request.
-        pub request_id: u64,
-        /// The request
-        pub request: PushRequest,
-    }
-
-    impl Request for PushRequestReceived {
-        fn id(&self) -> (u64, u64) {
-            (self.connection_id, self.request_id)
-        }
+        pub request: R,
     }
 
     /// Request to throttle sending for a specific request. 
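
For orientation, a minimal sketch (not part of the patch) of a consumer for this protocol, modeled on the `event_handler` in `src/tests.rs` later in this series; `EventMask::ALL`, the channel size, and the log output are illustrative:

```rust
use iroh_blobs::provider::events::{EventMask, EventSender, ProviderMessage};

// Channel on which the provider delivers ProviderMessage values.
let (tx, mut rx) = tokio::sync::mpsc::channel::<ProviderMessage>(16);
n0_future::task::spawn(async move {
    while let Some(msg) = rx.recv().await {
        // The generic RequestReceived<R> carries connection_id and request_id
        // for every request type, so all variants can be handled uniformly.
        if let ProviderMessage::GetRequestReceived(msg) = msg {
            println!("get request {} on connection {}", msg.request_id, msg.connection_id);
            // Replying on the oneshot channel decides whether the request proceeds.
            msg.tx.send(Ok(())).await.ok();
        }
    }
});
let events = EventSender::new(tx, EventMask::ALL);
```
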
@@ -524,7 +473,7 @@ mod proto { #[derive(Debug, Serialize, Deserialize)] pub struct TransferAborted { - pub stats: Option>, + pub stats: Box, } /// Stream of updates for a single request diff --git a/src/tests.rs b/src/tests.rs index e99d0fe02..0dda88fee 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -16,10 +16,7 @@ use crate::{ hashseq::HashSeq, net_protocol::BlobsProtocol, protocol::{ChunkRangesSeq, GetManyRequest, ObserveRequest, PushRequest}, - provider::{ - events::{AbortReason, RequestUpdate}, - EventMask, EventSender2, ProviderMessage, - }, + provider::events::{AbortReason, EventMask, EventSender, ProviderMessage, RequestUpdate}, store::{ fs::{ tests::{create_n0_bao, test_data, INTERESTING_SIZES}, @@ -343,7 +340,7 @@ async fn two_nodes_get_many_mem() -> TestResult<()> { fn event_handler( allowed_nodes: impl IntoIterator, -) -> (EventSender2, watch::Receiver, AbortOnDropHandle<()>) { +) -> (EventSender, watch::Receiver, AbortOnDropHandle<()>) { let (count_tx, count_rx) = tokio::sync::watch::channel(0usize); let (events_tx, mut events_rx) = mpsc::channel::(16); let allowed_nodes = allowed_nodes.into_iter().collect::>(); @@ -373,7 +370,7 @@ fn event_handler( } } })); - (EventSender2::new(events_tx, EventMask::ALL), count_rx, task) + (EventSender::new(events_tx, EventMask::ALL), count_rx, task) } async fn two_nodes_push_blobs( @@ -488,12 +485,12 @@ async fn check_presence(store: &Store, sizes: &[usize]) -> TestResult<()> { } pub async fn node_test_setup_fs(db_path: PathBuf) -> TestResult<(Router, FsStore, PathBuf)> { - node_test_setup_with_events_fs(db_path, EventSender2::NONE).await + node_test_setup_with_events_fs(db_path, EventSender::NONE).await } pub async fn node_test_setup_with_events_fs( db_path: PathBuf, - events: EventSender2, + events: EventSender, ) -> TestResult<(Router, FsStore, PathBuf)> { let store = crate::store::fs::FsStore::load(&db_path).await?; let ep = Endpoint::builder().bind().await?; @@ -503,11 +500,11 @@ pub async fn node_test_setup_with_events_fs( } pub async fn node_test_setup_mem() -> TestResult<(Router, MemStore)> { - node_test_setup_with_events_mem(EventSender2::NONE).await + node_test_setup_with_events_mem(EventSender::NONE).await } pub async fn node_test_setup_with_events_mem( - events: EventSender2, + events: EventSender, ) -> TestResult<(Router, MemStore)> { let store = MemStore::new(); let ep = Endpoint::builder().bind().await?; @@ -609,7 +606,7 @@ async fn node_serve_hash_seq() -> TestResult<()> { let root = root_tt.hash; let endpoint = Endpoint::builder().discovery_n0().bind().await?; let blobs = - crate::net_protocol::BlobsProtocol::new(&store, endpoint.clone(), EventSender2::NONE); + crate::net_protocol::BlobsProtocol::new(&store, endpoint.clone(), EventSender::NONE); let r1 = Router::builder(endpoint) .accept(crate::protocol::ALPN, blobs) .spawn(); @@ -641,7 +638,7 @@ async fn node_serve_blobs() -> TestResult<()> { } let endpoint = Endpoint::builder().discovery_n0().bind().await?; let blobs = - crate::net_protocol::BlobsProtocol::new(&store, endpoint.clone(), EventSender2::NONE); + crate::net_protocol::BlobsProtocol::new(&store, endpoint.clone(), EventSender::NONE); let r1 = Router::builder(endpoint) .accept(crate::protocol::ALPN, blobs) .spawn(); @@ -683,8 +680,7 @@ async fn node_smoke(store: &Store) -> TestResult<()> { let tt = store.add_bytes(b"hello world".to_vec()).temp_tag().await?; let hash = *tt.hash(); let endpoint = Endpoint::builder().discovery_n0().bind().await?; - let blobs = - crate::net_protocol::BlobsProtocol::new(store, 
endpoint.clone(), EventSender2::NONE); + let blobs = crate::net_protocol::BlobsProtocol::new(store, endpoint.clone(), EventSender::NONE); let r1 = Router::builder(endpoint) .accept(crate::protocol::ALPN, blobs) .spawn(); From 1e4a58161f320b56e55bc90841f943c5d5676579 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Mon, 1 Sep 2025 15:35:53 +0200 Subject: [PATCH 07/35] minimize diff and required changes --- README.md | 4 ++-- examples/custom-protocol.rs | 6 ++---- examples/mdns-discovery.rs | 4 ++-- examples/random_store.rs | 2 +- examples/transfer.rs | 6 ++---- src/net_protocol.rs | 8 ++++---- src/tests.rs | 12 +++++------- 7 files changed, 18 insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index 1a136e44d..2f374e8fb 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ Here is a basic example of how to set up `iroh-blobs` with `iroh`: ```rust,no_run use iroh::{protocol::Router, Endpoint}; -use iroh_blobs::{store::mem::MemStore, BlobsProtocol, provider::events::EventSender}; +use iroh_blobs::{store::mem::MemStore, BlobsProtocol}; #[tokio::main] async fn main() -> anyhow::Result<()> { @@ -44,7 +44,7 @@ async fn main() -> anyhow::Result<()> { // create a protocol handler using an in-memory blob store. let store = MemStore::new(); - let blobs = BlobsProtocol::new(&store, endpoint.clone(), EventSender::NONE); + let blobs = BlobsProtocol::new(&store, endpoint.clone(), None); // build the router let router = Router::builder(endpoint) diff --git a/examples/custom-protocol.rs b/examples/custom-protocol.rs index d4d29e27f..c021b7f0a 100644 --- a/examples/custom-protocol.rs +++ b/examples/custom-protocol.rs @@ -48,9 +48,7 @@ use iroh::{ protocol::{AcceptError, ProtocolHandler, Router}, Endpoint, NodeId, }; -use iroh_blobs::{ - api::Store, provider::events::EventSender, store::mem::MemStore, BlobsProtocol, Hash, -}; +use iroh_blobs::{api::Store, store::mem::MemStore, BlobsProtocol, Hash}; mod common; use common::{get_or_generate_secret_key, setup_logging}; @@ -102,7 +100,7 @@ async fn listen(text: Vec) -> Result<()> { proto.insert_and_index(text).await?; } // Build the iroh-blobs protocol handler, which is used to download blobs. - let blobs = BlobsProtocol::new(&store, endpoint.clone(), EventSender::NONE); + let blobs = BlobsProtocol::new(&store, endpoint.clone(), None); // create a router that handles both our custom protocol and the iroh-blobs protocol. 
let node = Router::builder(endpoint) diff --git a/examples/mdns-discovery.rs b/examples/mdns-discovery.rs index ab11bc864..b42f88f47 100644 --- a/examples/mdns-discovery.rs +++ b/examples/mdns-discovery.rs @@ -18,7 +18,7 @@ use clap::{Parser, Subcommand}; use iroh::{ discovery::mdns::MdnsDiscovery, protocol::Router, Endpoint, PublicKey, RelayMode, SecretKey, }; -use iroh_blobs::{provider::events::EventSender, store::mem::MemStore, BlobsProtocol, Hash}; +use iroh_blobs::{store::mem::MemStore, BlobsProtocol, Hash}; mod common; use common::{get_or_generate_secret_key, setup_logging}; @@ -68,7 +68,7 @@ async fn accept(path: &Path) -> Result<()> { .await?; let builder = Router::builder(endpoint.clone()); let store = MemStore::new(); - let blobs = BlobsProtocol::new(&store, endpoint.clone(), EventSender::NONE); + let blobs = BlobsProtocol::new(&store, endpoint.clone(), None); let builder = builder.accept(iroh_blobs::ALPN, blobs.clone()); let node = builder.spawn(); diff --git a/examples/random_store.rs b/examples/random_store.rs index f36017e8d..f23e804e1 100644 --- a/examples/random_store.rs +++ b/examples/random_store.rs @@ -238,7 +238,7 @@ async fn provide(args: ProvideArgs) -> anyhow::Result<()> { .bind() .await?; let (dump_task, events_tx) = dump_provider_events(args.allow_push); - let blobs = iroh_blobs::BlobsProtocol::new(&store, endpoint.clone(), events_tx); + let blobs = iroh_blobs::BlobsProtocol::new(&store, endpoint.clone(), Some(events_tx)); let router = iroh::protocol::Router::builder(endpoint.clone()) .accept(iroh_blobs::ALPN, blobs) .spawn(); diff --git a/examples/transfer.rs b/examples/transfer.rs index 8347774ca..48fba6ba3 100644 --- a/examples/transfer.rs +++ b/examples/transfer.rs @@ -1,9 +1,7 @@ use std::path::PathBuf; use iroh::{protocol::Router, Endpoint}; -use iroh_blobs::{ - provider::events::EventSender, store::mem::MemStore, ticket::BlobTicket, BlobsProtocol, -}; +use iroh_blobs::{store::mem::MemStore, ticket::BlobTicket, BlobsProtocol}; #[tokio::main] async fn main() -> anyhow::Result<()> { @@ -14,7 +12,7 @@ async fn main() -> anyhow::Result<()> { // We initialize an in-memory backing store for iroh-blobs let store = MemStore::new(); // Then we initialize a struct that can accept blobs requests over iroh connections - let blobs = BlobsProtocol::new(&store, endpoint.clone(), EventSender::NONE); + let blobs = BlobsProtocol::new(&store, endpoint.clone(), None); // Grab all passed in arguments, the first one is the binary itself, so we skip it. let args: Vec = std::env::args().skip(1).collect(); diff --git a/src/net_protocol.rs b/src/net_protocol.rs index aa45aa473..1927cd23d 100644 --- a/src/net_protocol.rs +++ b/src/net_protocol.rs @@ -7,7 +7,7 @@ //! ```rust //! # async fn example() -> anyhow::Result<()> { //! use iroh::{protocol::Router, Endpoint}; -//! use iroh_blobs::{provider::events::EventSender, store, BlobsProtocol}; +//! use iroh_blobs::{store, BlobsProtocol}; //! //! // create a store //! let store = store::fs::FsStore::load("blobs").await?; @@ -19,7 +19,7 @@ //! let endpoint = Endpoint::builder().discovery_n0().bind().await?; //! //! // create a blobs protocol handler -//! let blobs = BlobsProtocol::new(&store, endpoint.clone(), EventSender::NONE); +//! let blobs = BlobsProtocol::new(&store, endpoint.clone(), None); //! //! // create a router and add the blobs protocol handler //! 
let router = Router::builder(endpoint) @@ -69,12 +69,12 @@ impl Deref for BlobsProtocol { } impl BlobsProtocol { - pub fn new(store: &Store, endpoint: Endpoint, events: EventSender) -> Self { + pub fn new(store: &Store, endpoint: Endpoint, events: Option) -> Self { Self { inner: Arc::new(BlobsInner { store: store.clone(), endpoint, - events, + events: events.unwrap_or(EventSender::NONE), }), } } diff --git a/src/tests.rs b/src/tests.rs index 0dda88fee..911a5cfd2 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -494,7 +494,7 @@ pub async fn node_test_setup_with_events_fs( ) -> TestResult<(Router, FsStore, PathBuf)> { let store = crate::store::fs::FsStore::load(&db_path).await?; let ep = Endpoint::builder().bind().await?; - let blobs = BlobsProtocol::new(&store, ep.clone(), events); + let blobs = BlobsProtocol::new(&store, ep.clone(), Some(events)); let router = Router::builder(ep).accept(crate::ALPN, blobs).spawn(); Ok((router, store, db_path)) } @@ -508,7 +508,7 @@ pub async fn node_test_setup_with_events_mem( ) -> TestResult<(Router, MemStore)> { let store = MemStore::new(); let ep = Endpoint::builder().bind().await?; - let blobs = BlobsProtocol::new(&store, ep.clone(), events); + let blobs = BlobsProtocol::new(&store, ep.clone(), Some(events)); let router = Router::builder(ep).accept(crate::ALPN, blobs).spawn(); Ok((router, store)) } @@ -605,8 +605,7 @@ async fn node_serve_hash_seq() -> TestResult<()> { let root_tt = store.add_bytes(hash_seq).await?; let root = root_tt.hash; let endpoint = Endpoint::builder().discovery_n0().bind().await?; - let blobs = - crate::net_protocol::BlobsProtocol::new(&store, endpoint.clone(), EventSender::NONE); + let blobs = crate::net_protocol::BlobsProtocol::new(&store, endpoint.clone(), None); let r1 = Router::builder(endpoint) .accept(crate::protocol::ALPN, blobs) .spawn(); @@ -637,8 +636,7 @@ async fn node_serve_blobs() -> TestResult<()> { tts.push(store.add_bytes(test_data(size)).await?); } let endpoint = Endpoint::builder().discovery_n0().bind().await?; - let blobs = - crate::net_protocol::BlobsProtocol::new(&store, endpoint.clone(), EventSender::NONE); + let blobs = crate::net_protocol::BlobsProtocol::new(&store, endpoint.clone(), None); let r1 = Router::builder(endpoint) .accept(crate::protocol::ALPN, blobs) .spawn(); @@ -680,7 +678,7 @@ async fn node_smoke(store: &Store) -> TestResult<()> { let tt = store.add_bytes(b"hello world".to_vec()).temp_tag().await?; let hash = *tt.hash(); let endpoint = Endpoint::builder().discovery_n0().bind().await?; - let blobs = crate::net_protocol::BlobsProtocol::new(store, endpoint.clone(), EventSender::NONE); + let blobs = crate::net_protocol::BlobsProtocol::new(store, endpoint.clone(), None); let r1 = Router::builder(endpoint) .accept(crate::protocol::ALPN, blobs) .spawn(); From 64499308baebe3c51f0e38a6a2ab8e986f06f56b Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Mon, 1 Sep 2025 17:12:59 +0200 Subject: [PATCH 08/35] clippy --- src/api/blobs.rs | 17 --- src/net_protocol.rs | 2 +- src/protocol.rs | 47 +++++---- src/provider.rs | 235 ++++++++++++++--------------------------- src/provider/events.rs | 151 +++++++++++++------------- src/tests.rs | 4 +- 6 files changed, 192 insertions(+), 264 deletions(-) diff --git a/src/api/blobs.rs b/src/api/blobs.rs index d00a0a940..8b618de1f 100644 --- a/src/api/blobs.rs +++ b/src/api/blobs.rs @@ -57,7 +57,6 @@ use super::{ }; use crate::{ api::proto::{BatchRequest, ImportByteStreamUpdate}, - provider::WriterContext, store::IROH_BLOCK_SIZE, util::temp_tag::TempTag, BlobFormat, 
Hash, HashAndFormat, @@ -1167,19 +1166,3 @@ pub(crate) trait WriteProgress { /// Notify the progress writer that a transfer has started. async fn send_transfer_started(&mut self, index: u64, hash: &Hash, size: u64); } - -impl WriteProgress for WriterContext { - async fn notify_payload_write(&mut self, _index: u64, offset: u64, len: usize) { - let end_offset = offset + len as u64; - self.payload_bytes_written += len as u64; - self.tracker.transfer_progress(end_offset).await.ok(); - } - - fn log_other_write(&mut self, len: usize) { - self.other_bytes_written += len as u64; - } - - async fn send_transfer_started(&mut self, index: u64, hash: &Hash, size: u64) { - self.tracker.transfer_started(index, hash, size).await.ok(); - } -} diff --git a/src/net_protocol.rs b/src/net_protocol.rs index 1927cd23d..269ef0e14 100644 --- a/src/net_protocol.rs +++ b/src/net_protocol.rs @@ -74,7 +74,7 @@ impl BlobsProtocol { inner: Arc::new(BlobsInner { store: store.clone(), endpoint, - events: events.unwrap_or(EventSender::NONE), + events: events.unwrap_or(EventSender::DEFAULT), }), } } diff --git a/src/protocol.rs b/src/protocol.rs index 74e0f986d..05ee00678 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -392,7 +392,7 @@ pub use range_spec::{ChunkRangesSeq, NonEmptyRequestRangeSpecIter, RangeSpec}; use snafu::{GenerateImplicitData, Snafu}; use tokio::io::AsyncReadExt; -use crate::{api::blobs::Bitfield, provider::CountingReader, BlobFormat, Hash, HashAndFormat}; +use crate::{api::blobs::Bitfield, provider::RecvStreamExt, BlobFormat, Hash, HashAndFormat}; /// Maximum message size is limited to 100MiB for now. pub const MAX_MESSAGE_SIZE: usize = 1024 * 1024; @@ -441,9 +441,7 @@ pub enum RequestType { } impl Request { - pub async fn read_async( - reader: &mut CountingReader<&mut iroh::endpoint::RecvStream>, - ) -> io::Result { + pub async fn read_async(reader: &mut iroh::endpoint::RecvStream) -> io::Result<(Self, usize)> { let request_type = reader.read_u8().await?; let request_type: RequestType = postcard::from_bytes(std::slice::from_ref(&request_type)) .map_err(|_| { @@ -453,22 +451,31 @@ impl Request { ) })?; Ok(match request_type { - RequestType::Get => reader - .read_to_end_as::(MAX_MESSAGE_SIZE) - .await? - .into(), - RequestType::GetMany => reader - .read_to_end_as::(MAX_MESSAGE_SIZE) - .await? - .into(), - RequestType::Observe => reader - .read_to_end_as::(MAX_MESSAGE_SIZE) - .await? - .into(), - RequestType::Push => reader - .read_length_prefixed::(MAX_MESSAGE_SIZE) - .await? 
-                .into(),
+            RequestType::Get => {
+                let (r, size) = reader
+                    .read_to_end_as::(MAX_MESSAGE_SIZE)
+                    .await?;
+                (r.into(), size)
+            }
+            RequestType::GetMany => {
+                let (r, size) = reader
+                    .read_to_end_as::(MAX_MESSAGE_SIZE)
+                    .await?;
+                (r.into(), size)
+            }
+            RequestType::Observe => {
+                let (r, size) = reader
+                    .read_to_end_as::(MAX_MESSAGE_SIZE)
+                    .await?;
+                (r.into(), size)
+            }
+            RequestType::Push => {
+                let r = reader
+                    .read_length_prefixed::(MAX_MESSAGE_SIZE)
+                    .await?;
+                let size = postcard::experimental::serialized_size(&r).unwrap();
+                (r.into(), size)
+            }
             _ => {
                 return Err(io::Error::new(
                     io::ErrorKind::InvalidData,
diff --git a/src/provider.rs b/src/provider.rs
index a98c1b3a7..1683daa57 100644
--- a/src/provider.rs
+++ b/src/provider.rs
@@ -6,9 +6,6 @@
 use std::{
     fmt::Debug,
     io,
-    ops::{Deref, DerefMut},
-    pin::Pin,
-    task::Poll,
     time::{Duration, Instant},
 };
 
@@ -18,7 +15,7 @@ use iroh::endpoint::{self, RecvStream, SendStream};
 use n0_future::StreamExt;
 use quinn::{ClosedStream, ConnectionError, ReadToEndError};
 use serde::{de::DeserializeOwned, Deserialize, Serialize};
-use tokio::{io::AsyncRead, select};
+use tokio::select;
 use tracing::{debug, debug_span, warn, Instrument};
 
 use crate::{
@@ -53,23 +50,9 @@ pub struct TransferStats {
     pub duration: Duration,
 }
 
-/// Read the request from the getter.
-///
-/// Will fail if there is an error while reading, or if no valid request is sent.
-///
-/// This will read exactly the number of bytes needed for the request, and
-/// leave the rest of the stream for the caller to read.
-///
-/// It is up to the caller do decide if there should be more data.
-pub async fn read_request(context: &mut StreamData) -> Result {
-    let mut counting = CountingReader::new(&mut context.reader);
-    let res = Request::read_async(&mut counting).await?;
-    context.other_bytes_read += counting.read();
-    Ok(res)
-}
-
+/// A pair of [`SendStream`] and [`RecvStream`] with additional context data.
 #[derive(Debug)]
-pub struct StreamData {
+pub struct StreamPair {
     t0: Instant,
     connection_id: u64,
     request_id: u64,
@@ -79,7 +62,7 @@ pub struct StreamData {
     events: EventSender,
 }
 
-impl StreamData {
+impl StreamPair {
     pub async fn accept(
         conn: &endpoint::Connection,
         events: &EventSender,
@@ -96,8 +79,22 @@ impl StreamData {
         })
     }
 
+    /// Read the request.
+    ///
+    /// Will fail if there is an error while reading, or if no valid request is sent.
+    ///
+    /// This will read exactly the number of bytes needed for the request, and
+    /// leave the rest of the stream for the caller to read.
+    ///
+    /// It is up to the caller to decide if there should be more data.
+    pub async fn read_request(&mut self) -> Result {
+        let (res, size) = Request::read_async(&mut self.reader).await?;
+        self.other_bytes_read += size as u64;
+        Ok(res)
+    }
+
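A quick sketch (not part of the patch) of how the accept loop in `handle_connection` further down drives this type; error handling is elided:

```rust
// Accept the next bidirectional stream together with its context, then
// decode the request that starts it.
while let Ok(mut context) = StreamPair::accept(&connection, &progress).await {
    let request = context.read_request().await?;
    // context is then converted into a ProgressWriter (get/get_many/observe)
    // or a ProgressReader (push) to track the rest of the transfer.
}
```

     /// We are done with reading. 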
Return a ProgressWriter that contains the read stats and connection id - async fn into_writer( + pub async fn into_writer( mut self, tracker: RequestTracker, ) -> Result { @@ -113,8 +110,6 @@ impl StreamData { self.writer, WriterContext { t0: self.t0, - connection_id: self.connection_id, - request_id: self.request_id, other_bytes_read: self.other_bytes_read, payload_bytes_written: 0, other_bytes_written: 0, @@ -123,7 +118,7 @@ impl StreamData { )) } - async fn into_reader( + pub async fn into_reader( mut self, tracker: RequestTracker, ) -> Result { @@ -135,19 +130,17 @@ impl StreamData { .ok(); return Err(e); }; - Ok(ProgressReader::new( - self.reader, - ReaderContext { + Ok(ProgressReader { + inner: self.reader, + context: ReaderContext { t0: self.t0, - connection_id: self.connection_id, - request_id: self.request_id, other_bytes_read: self.other_bytes_read, tracker, }, - )) + }) } - async fn get_request( + pub async fn get_request( &self, f: impl FnOnce() -> GetRequest, ) -> Result { @@ -156,7 +149,7 @@ impl StreamData { .await } - async fn get_many_request( + pub async fn get_many_request( &self, f: impl FnOnce() -> GetManyRequest, ) -> Result { @@ -165,7 +158,7 @@ impl StreamData { .await } - async fn push_request( + pub async fn push_request( &self, f: impl FnOnce() -> PushRequest, ) -> Result { @@ -174,7 +167,7 @@ impl StreamData { .await } - async fn observe_request( + pub async fn observe_request( &self, f: impl FnOnce() -> ObserveRequest, ) -> Result { @@ -194,31 +187,17 @@ impl StreamData { } #[derive(Debug)] -pub struct ReaderContext { +struct ReaderContext { /// The start time of the transfer - pub t0: Instant, - /// The connection ID from the connection - pub connection_id: u64, - /// The request ID from the recv stream - pub request_id: u64, + t0: Instant, /// The number of bytes read from the stream - pub other_bytes_read: u64, + other_bytes_read: u64, /// Progress tracking for the request - pub tracker: RequestTracker, + tracker: RequestTracker, } impl ReaderContext { - pub fn new(context: StreamData, tracker: RequestTracker) -> Self { - Self { - t0: context.t0, - connection_id: context.connection_id, - request_id: context.request_id, - other_bytes_read: context.other_bytes_read, - tracker, - } - } - - pub fn stats(&self) -> TransferStats { + fn stats(&self) -> TransferStats { TransferStats { payload_bytes_sent: 0, other_bytes_sent: 0, @@ -229,37 +208,21 @@ impl ReaderContext { } #[derive(Debug)] -pub struct WriterContext { +pub(crate) struct WriterContext { /// The start time of the transfer - pub t0: Instant, - /// The connection ID from the connection - pub connection_id: u64, - /// The request ID from the recv stream - pub request_id: u64, + t0: Instant, /// The number of bytes read from the stream - pub other_bytes_read: u64, + other_bytes_read: u64, /// The number of payload bytes written to the stream - pub payload_bytes_written: u64, + payload_bytes_written: u64, /// The number of bytes written that are not part of the payload - pub other_bytes_written: u64, + other_bytes_written: u64, /// Way to report progress - pub tracker: RequestTracker, + tracker: RequestTracker, } impl WriterContext { - pub fn new(context: &StreamData, tracker: RequestTracker) -> Self { - Self { - t0: context.t0, - connection_id: context.connection_id, - request_id: context.request_id, - other_bytes_read: context.other_bytes_read, - payload_bytes_written: 0, - other_bytes_written: 0, - tracker, - } - } - - pub fn stats(&self) -> TransferStats { + fn stats(&self) -> TransferStats { 
TransferStats { payload_bytes_sent: self.payload_bytes_written, other_bytes_sent: self.other_bytes_written, @@ -269,6 +232,22 @@ impl WriterContext { } } +impl WriteProgress for WriterContext { + async fn notify_payload_write(&mut self, _index: u64, offset: u64, len: usize) { + let end_offset = offset + len as u64; + self.payload_bytes_written += len as u64; + self.tracker.transfer_progress(end_offset).await.ok(); + } + + fn log_other_write(&mut self, len: usize) { + self.other_bytes_written += len as u64; + } + + async fn send_transfer_started(&mut self, index: u64, hash: &Hash, size: u64) { + self.tracker.transfer_started(index, hash, size).await.ok(); + } +} + /// Wrapper for a [`quinn::SendStream`] with additional per request information. #[derive(Debug)] pub struct ProgressWriter { @@ -283,34 +262,22 @@ impl ProgressWriter { } async fn transfer_aborted(&self) { - self.tracker - .transfer_aborted(|| Box::new(self.stats())) + self.context + .tracker + .transfer_aborted(|| Box::new(self.context.stats())) .await .ok(); } async fn transfer_completed(&self) { - self.tracker - .transfer_completed(|| Box::new(self.stats())) + self.context + .tracker + .transfer_completed(|| Box::new(self.context.stats())) .await .ok(); } } -impl Deref for ProgressWriter { - type Target = WriterContext; - - fn deref(&self) -> &Self::Target { - &self.context - } -} - -impl DerefMut for ProgressWriter { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.context - } -} - /// Handle a single connection. pub async fn handle_connection( connection: endpoint::Connection, @@ -334,7 +301,7 @@ pub async fn handle_connection( debug!("client not authorized to connect: {cause}"); return; } - while let Ok(context) = StreamData::accept(&connection, &progress).await { + while let Ok(context) = StreamPair::accept(&connection, &progress).await { let span = debug_span!("stream", stream_id = %context.request_id); let store = store.clone(); tokio::spawn(handle_stream(store, context).instrument(span)); @@ -348,10 +315,10 @@ pub async fn handle_connection( .await } -async fn handle_stream(store: Store, mut context: StreamData) -> anyhow::Result<()> { +async fn handle_stream(store: Store, mut context: StreamPair) -> anyhow::Result<()> { // 1. Decode the request. 
debug!("reading request"); - let request = read_request(&mut context).await?; + let request = context.read_request().await?; match request { Request::Get(request) => { @@ -553,79 +520,41 @@ pub struct ProgressReader { } impl ProgressReader { - pub fn new(inner: RecvStream, context: ReaderContext) -> Self { - Self { inner, context } - } - async fn transfer_aborted(&self) { - self.tracker - .transfer_aborted(|| Box::new(self.stats())) + self.context + .tracker + .transfer_aborted(|| Box::new(self.context.stats())) .await .ok(); } async fn transfer_completed(&self) { - self.tracker - .transfer_completed(|| Box::new(self.stats())) + self.context + .tracker + .transfer_completed(|| Box::new(self.context.stats())) .await .ok(); } } -impl Deref for ProgressReader { - type Target = ReaderContext; - - fn deref(&self) -> &Self::Target { - &self.context - } +pub(crate) trait RecvStreamExt { + async fn read_to_end_as( + &mut self, + max_size: usize, + ) -> io::Result<(T, usize)>; } -impl DerefMut for ProgressReader { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.context - } -} - -pub struct CountingReader { - pub inner: R, - pub read: u64, -} - -impl CountingReader { - pub fn new(inner: R) -> Self { - Self { inner, read: 0 } - } - - pub fn read(&self) -> u64 { - self.read - } -} - -impl CountingReader<&mut iroh::endpoint::RecvStream> { - pub async fn read_to_end_as(&mut self, max_size: usize) -> io::Result { +impl RecvStreamExt for RecvStream { + async fn read_to_end_as( + &mut self, + max_size: usize, + ) -> io::Result<(T, usize)> { let data = self - .inner .read_to_end(max_size) .await .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; let value = postcard::from_bytes(&data) .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - self.read += data.len() as u64; - Ok(value) - } -} - -impl AsyncRead for CountingReader { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> Poll> { - let this = self.get_mut(); - let result = Pin::new(&mut this.inner).poll_read(cx, buf); - if let Poll::Ready(Ok(())) = result { - this.read += buf.filled().len() as u64; - } - result + Ok((value, data.len())) } } diff --git a/src/provider/events.rs b/src/provider/events.rs index c8b94c8b8..1ecf13cb7 100644 --- a/src/provider/events.rs +++ b/src/provider/events.rs @@ -52,6 +52,8 @@ pub enum RequestMode { /// We get a request for each request, and can reject incoming requests. /// We also get detailed transfer events. RequestLog, + /// This request type is completely disabled. All requests will be rejected. 
+ Disabled, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] @@ -108,24 +110,35 @@ impl From for ClientError { pub type EventResult = Result<(), AbortReason>; pub type ClientResult = Result<(), ClientError>; -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct EventMask { - connected: ConnectMode, - get: RequestMode, - get_many: RequestMode, - push: RequestMode, - observe: ObserveMode, + /// Connection event mask + pub connected: ConnectMode, + /// Get request event mask + pub get: RequestMode, + /// Get many request event mask + pub get_many: RequestMode, + /// Push request event mask + pub push: RequestMode, + /// Observe request event mask + pub observe: ObserveMode, /// throttling is somewhat costly, so you can disable it completely - throttle: ThrottleMode, + pub throttle: ThrottleMode, +} + +impl Default for EventMask { + fn default() -> Self { + Self::DEFAULT + } } impl EventMask { - /// Everything is disabled. You won't get any events, but there is also no runtime cost. - pub const NONE: Self = Self { + /// All event notifications are fully disabled. Push requests are disabled by default. + pub const DEFAULT: Self = Self { connected: ConnectMode::None, get: RequestMode::None, get_many: RequestMode::None, - push: RequestMode::None, + push: RequestMode::Disabled, throttle: ThrottleMode::None, observe: ObserveMode::None, }; @@ -139,16 +152,6 @@ impl EventMask { throttle: ThrottleMode::Throttle, observe: ObserveMode::Request, }; - - /// You get notified for every single thing that is going on, but can't intervene. - pub const NOTIFY_ALL: Self = Self { - connected: ConnectMode::Notify, - get: RequestMode::NotifyLog, - get_many: RequestMode::NotifyLog, - push: RequestMode::NotifyLog, - throttle: ThrottleMode::None, - observe: ObserveMode::Notify, - }; } /// Newtype wrapper that wraps an event so that it is a distinct type for the notify variant. @@ -248,8 +251,8 @@ impl RequestTracker { /// can have a response. impl EventSender { /// A client that does not send anything. - pub const NONE: Self = Self { - mask: EventMask::NONE, + pub const DEFAULT: Self = Self { + mask: EventMask::DEFAULT, inner: None, }; @@ -262,20 +265,22 @@ impl EventSender { /// A new client has been connected. pub async fn client_connected(&self, f: impl Fn() -> ClientConnected) -> ClientResult { - Ok(if let Some(client) = &self.inner { + if let Some(client) = &self.inner { match self.mask.connected { ConnectMode::None => {} ConnectMode::Notify => client.notify(Notify(f())).await?, ConnectMode::Request => client.rpc(f()).await??, } - }) + }; + Ok(()) } /// A new client has been connected. pub async fn connection_closed(&self, f: impl Fn() -> ConnectionClosed) -> ClientResult { - Ok(if let Some(client) = &self.inner { + if let Some(client) = &self.inner { client.notify(f()).await?; - }) + }; + Ok(()) } /// Abstract request, to DRY the 3 to 4 request types. @@ -300,58 +305,61 @@ impl EventSender { Notify>: Channels>, { - Ok(self.into_tracker(( - if let Some(client) = &self.inner { - match self.mask.get { - RequestMode::None => RequestUpdates::None, - RequestMode::Notify => { - let msg = RequestReceived { - request: f(), - connection_id, - request_id, - }; - RequestUpdates::Disabled(client.notify_streaming(Notify(msg), 32).await?) 
- } - RequestMode::Request => { - let msg = RequestReceived { - request: f(), - connection_id, - request_id, - }; - let (tx, rx) = client.client_streaming(msg, 32).await?; - // bail out if the request is not allowed - rx.await??; - RequestUpdates::Disabled(tx) - } - RequestMode::NotifyLog => { - let msg = RequestReceived { - request: f(), - connection_id, - request_id, - }; - RequestUpdates::Active(client.notify_streaming(Notify(msg), 32).await?) - } - RequestMode::RequestLog => { - let msg = RequestReceived { - request: f(), - connection_id, - request_id, - }; - let (tx, rx) = client.client_streaming(msg, 32).await?; - // bail out if the request is not allowed - rx.await??; - RequestUpdates::Active(tx) - } + let client = self.inner.as_ref(); + Ok(self.create_tracker(( + match self.mask.get { + RequestMode::None => RequestUpdates::None, + RequestMode::Notify if client.is_some() => { + let msg = RequestReceived { + request: f(), + connection_id, + request_id, + }; + RequestUpdates::Disabled( + client.unwrap().notify_streaming(Notify(msg), 32).await?, + ) } - } else { - RequestUpdates::None + RequestMode::Request if client.is_some() => { + let msg = RequestReceived { + request: f(), + connection_id, + request_id, + }; + let (tx, rx) = client.unwrap().client_streaming(msg, 32).await?; + // bail out if the request is not allowed + rx.await??; + RequestUpdates::Disabled(tx) + } + RequestMode::NotifyLog if client.is_some() => { + let msg = RequestReceived { + request: f(), + connection_id, + request_id, + }; + RequestUpdates::Active(client.unwrap().notify_streaming(Notify(msg), 32).await?) + } + RequestMode::RequestLog if client.is_some() => { + let msg = RequestReceived { + request: f(), + connection_id, + request_id, + }; + let (tx, rx) = client.unwrap().client_streaming(msg, 32).await?; + // bail out if the request is not allowed + rx.await??; + RequestUpdates::Active(tx) + } + RequestMode::Disabled => { + return Err(ClientError::Permission); + } + _ => RequestUpdates::None, }, connection_id, request_id, ))) } - fn into_tracker( + fn create_tracker( &self, (updates, connection_id, request_id): (RequestUpdates, u64, u64), ) -> RequestTracker { @@ -372,6 +380,7 @@ pub enum ProviderProto { /// A new client connected to the provider. #[rpc(tx = oneshot::Sender)] ClientConnected(ClientConnected), + /// A new client connected to the provider. Notify variant. 
#[rpc(tx = NoSender)] ClientConnectedNotify(Notify), diff --git a/src/tests.rs b/src/tests.rs index 911a5cfd2..40d9519c9 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -485,7 +485,7 @@ async fn check_presence(store: &Store, sizes: &[usize]) -> TestResult<()> { } pub async fn node_test_setup_fs(db_path: PathBuf) -> TestResult<(Router, FsStore, PathBuf)> { - node_test_setup_with_events_fs(db_path, EventSender::NONE).await + node_test_setup_with_events_fs(db_path, EventSender::DEFAULT).await } pub async fn node_test_setup_with_events_fs( @@ -500,7 +500,7 @@ pub async fn node_test_setup_with_events_fs( } pub async fn node_test_setup_mem() -> TestResult<(Router, MemStore)> { - node_test_setup_with_events_mem(EventSender::NONE).await + node_test_setup_with_events_mem(EventSender::DEFAULT).await } pub async fn node_test_setup_with_events_mem( From b26aefb4e2926fd980084cc95df56d8c3ff6ed45 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Mon, 1 Sep 2025 18:00:55 +0200 Subject: [PATCH 09/35] Footgun protection --- Cargo.lock | 4 ++ Cargo.toml | 2 +- examples/random_store.rs | 2 +- src/provider/events.rs | 94 ++++++++++++++++++++++++++++++++++++++-- src/tests.rs | 11 ++++- 5 files changed, 106 insertions(+), 7 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1a4de777e..4068354f7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1943,6 +1943,8 @@ dependencies = [ [[package]] name = "irpc" version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9f8f1d0987ea9da3d74698f921d0a817a214c83b2635a33ed4bc3efa4de1acd" dependencies = [ "anyhow", "futures-buffered", @@ -1964,6 +1966,8 @@ dependencies = [ [[package]] name = "irpc-derive" version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e0b26b834d401a046dd9d47bc236517c746eddbb5d25ff3e1a6075bfa4eebdb" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 3a642632c..bcd5f42d0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,7 +42,7 @@ self_cell = "1.1.0" genawaiter = { version = "0.99.1", features = ["futures03"] } iroh-base = "0.91.1" reflink-copy = "0.1.24" -irpc = { version = "0.7.0", features = ["rpc", "quinn_endpoint_setup", "spans", "stream", "derive"], default-features = false, path = "../irpc" } +irpc = { version = "0.7.0", features = ["rpc", "quinn_endpoint_setup", "spans", "stream", "derive"], default-features = false } iroh-metrics = { version = "0.35" } [dev-dependencies] diff --git a/examples/random_store.rs b/examples/random_store.rs index f23e804e1..c4c30348b 100644 --- a/examples/random_store.rs +++ b/examples/random_store.rs @@ -176,7 +176,7 @@ pub fn dump_provider_events(allow_push: bool) -> (tokio::task::JoinHandle<()>, E } } }); - (dump_task, EventSender::new(tx, EventMask::ALL)) + (dump_task, EventSender::new(tx, EventMask::ALL_READONLY)) } #[tokio::main] diff --git a/src/provider/events.rs b/src/provider/events.rs index 1ecf13cb7..b7fc58daa 100644 --- a/src/provider/events.rs +++ b/src/provider/events.rs @@ -1,4 +1,4 @@ -use std::fmt::Debug; +use std::{fmt::Debug, ops::Deref}; use irpc::{ channel::{mpsc, none::NoSender, oneshot}, @@ -143,12 +143,16 @@ impl EventMask { observe: ObserveMode::None, }; - /// You get asked for every single thing that is going on and can intervene/throttle. - pub const ALL: Self = Self { + /// All event notifications for read-only requests are fully enabled. + /// + /// If you want to enable push requests, which can write to the local store, you + /// need to do it manually. 
Providing constants that have push enabled would + /// risk misuse. + pub const ALL_READONLY: Self = Self { connected: ConnectMode::Request, get: RequestMode::RequestLog, get_many: RequestMode::RequestLog, - push: RequestMode::RequestLog, + push: RequestMode::Disabled, throttle: ThrottleMode::Throttle, observe: ObserveMode::Request, }; @@ -158,6 +162,14 @@ impl EventMask { #[derive(Debug, Serialize, Deserialize)] pub struct Notify(T); +impl Deref for Notify { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + #[derive(Debug, Default, Clone)] pub struct EventSender { mask: EventMask, @@ -263,6 +275,80 @@ impl EventSender { } } + /// Log request events at trace level. + pub fn tracing(&self, mask: EventMask) -> Self { + use tracing::trace; + let (tx, mut rx) = tokio::sync::mpsc::channel(32); + n0_future::task::spawn(async move { + fn log_request_events( + mut rx: irpc::channel::mpsc::Receiver, + connection_id: u64, + request_id: u64, + ) { + n0_future::task::spawn(async move { + while let Ok(Some(update)) = rx.recv().await { + trace!(%connection_id, %request_id, "{update:?}"); + } + }); + } + while let Some(msg) = rx.recv().await { + match msg { + ProviderMessage::ClientConnected(_) => todo!(), + ProviderMessage::ClientConnectedNotify(msg) => { + trace!("{:?}", msg.inner); + } + ProviderMessage::ConnectionClosed(msg) => { + trace!("{:?}", msg.inner); + } + ProviderMessage::GetRequestReceived(msg) => { + trace!("{:?}", msg.inner); + msg.tx.send(Ok(())).await.ok(); + log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id); + } + ProviderMessage::GetRequestReceivedNotify(msg) => { + trace!("{:?}", msg.inner); + log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id); + } + ProviderMessage::GetManyRequestReceived(msg) => { + trace!("{:?}", msg.inner); + msg.tx.send(Ok(())).await.ok(); + log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id); + } + ProviderMessage::GetManyRequestReceivedNotify(msg) => { + trace!("{:?}", msg.inner); + log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id); + } + ProviderMessage::PushRequestReceived(msg) => { + trace!("{:?}", msg.inner); + msg.tx.send(Ok(())).await.ok(); + log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id); + } + ProviderMessage::PushRequestReceivedNotify(msg) => { + trace!("{:?}", msg.inner); + log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id); + } + ProviderMessage::ObserveRequestReceived(msg) => { + trace!("{:?}", msg.inner); + msg.tx.send(Ok(())).await.ok(); + log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id); + } + ProviderMessage::ObserveRequestReceivedNotify(msg) => { + trace!("{:?}", msg.inner); + log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id); + } + ProviderMessage::Throttle(msg) => { + trace!("{:?}", msg.inner); + msg.tx.send(Ok(())).await.ok(); + } + } + } + }); + Self { + mask, + inner: Some(irpc::Client::from(tx)), + } + } + /// A new client has been connected. 
     pub async fn client_connected(&self, f: impl Fn() -> ClientConnected) -> ClientResult {
         if let Some(client) = &self.inner {
diff --git a/src/tests.rs b/src/tests.rs
index 40d9519c9..dc38eb436 100644
--- a/src/tests.rs
+++ b/src/tests.rs
@@ -370,7 +370,16 @@ fn event_handler(
             }
         }
     }));
-    (EventSender::new(events_tx, EventMask::ALL), count_rx, task)
+    (
+        EventSender::new(
+            events_tx,
+            EventMask {
+                ..EventMask::ALL_READONLY
+            },
+        ),
+        count_rx,
+        task,
+    )
 }
 
 async fn two_nodes_push_blobs(
From 6d86e4f2403d73c46e1cf66909d5c38297cb3e36 Mon Sep 17 00:00:00 2001
From: Ruediger Klaehn 
Date: Tue, 2 Sep 2025 12:51:02 +0200
Subject: [PATCH 10/35] Add limit example

This shows how to limit serving content in various ways
- by node id
- by content hash
- throttling
- limiting max number of connections
---
 examples/limit.rs | 341 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 341 insertions(+)
 create mode 100644 examples/limit.rs

diff --git a/examples/limit.rs b/examples/limit.rs
new file mode 100644
index 000000000..f23fad96c
--- /dev/null
+++ b/examples/limit.rs
@@ -0,0 +1,341 @@
+/// Example of how to limit blob requests by hash and node id, and how to add
+/// further restrictions such as throttling or a limit on concurrent connections.
+mod common;
+use std::{
+    collections::{HashMap, HashSet},
+    path::PathBuf,
+    sync::{
+        atomic::{AtomicUsize, Ordering},
+        Arc,
+    },
+};
+
+use clap::Parser;
+use common::setup_logging;
+use iroh::{NodeId, SecretKey, Watcher};
+use iroh_blobs::{
+    provider::events::{
+        AbortReason, ConnectMode, EventMask, EventSender, ProviderMessage, RequestMode,
+        ThrottleMode,
+    },
+    store::mem::MemStore,
+    ticket::BlobTicket,
+    BlobsProtocol, Hash,
+};
+use rand::thread_rng;
+
+use crate::common::get_or_generate_secret_key;
+
+#[derive(Debug, Parser)]
+#[command(version, about)]
+pub enum Args {
+    ByNodeId {
+        /// Path for files to add
+        paths: Vec,
+        #[clap(long("allow"))]
+        /// Nodes that are allowed to download content. 
+ allowed_nodes: Vec, + #[clap(long, default_value_t = 1)] + secrets: usize, + }, + ByHash { + /// Path for files to add + paths: Vec, + }, + Throttle { + /// Path for files to add + paths: Vec, + #[clap(long, default_value = "100")] + delay_ms: u64, + }, + MaxConnections { + /// Path for files to add + paths: Vec, + #[clap(long, default_value = "1")] + max_connections: usize, + }, + Get { + /// Ticket for the blob to download + ticket: BlobTicket, + }, +} + +fn limit_by_node_id(allowed_nodes: HashSet) -> EventSender { + let (tx, mut rx) = tokio::sync::mpsc::channel(32); + n0_future::task::spawn(async move { + while let Some(msg) = rx.recv().await { + match msg { + ProviderMessage::ClientConnected(msg) => { + let node_id = msg.node_id; + let res = if allowed_nodes.contains(&node_id) { + println!("Client connected: {node_id}"); + Ok(()) + } else { + println!("Client rejected: {node_id}"); + Err(AbortReason::Permission) + }; + msg.tx.send(res).await.ok(); + } + _ => {} + } + } + }); + EventSender::new( + tx, + EventMask { + connected: ConnectMode::Request, + ..EventMask::DEFAULT + }, + ) +} + +fn limit_by_hash(allowed_hashes: HashSet) -> EventSender { + let (tx, mut rx) = tokio::sync::mpsc::channel(32); + n0_future::task::spawn(async move { + while let Some(msg) = rx.recv().await { + match msg { + ProviderMessage::GetRequestReceived(msg) => { + let res = if !msg.request.ranges.is_blob() { + println!("HashSeq request not allowed"); + Err(AbortReason::Permission) + } else if !allowed_hashes.contains(&msg.request.hash) { + println!("Request for hash {} not allowed", msg.request.hash); + Err(AbortReason::Permission) + } else { + println!("Request for hash {} allowed", msg.request.hash); + Ok(()) + }; + msg.tx.send(res).await.ok(); + } + _ => {} + } + } + }); + EventSender::new( + tx, + EventMask { + get: RequestMode::Request, + ..EventMask::DEFAULT + }, + ) +} + +fn throttle(delay_ms: u64) -> EventSender { + let (tx, mut rx) = tokio::sync::mpsc::channel(32); + n0_future::task::spawn(async move { + while let Some(msg) = rx.recv().await { + match msg { + ProviderMessage::Throttle(msg) => { + n0_future::task::spawn(async move { + println!( + "Throttling {} {}, {}ms", + msg.connection_id, msg.request_id, delay_ms + ); + tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; + msg.tx.send(Ok(())).await.ok(); + }); + } + _ => {} + } + } + }); + EventSender::new( + tx, + EventMask { + throttle: ThrottleMode::Throttle, + ..EventMask::DEFAULT + }, + ) +} + +fn limit_max_connections(max_connections: usize) -> EventSender { + let (tx, mut rx) = tokio::sync::mpsc::channel(32); + n0_future::task::spawn(async move { + let requests = Arc::new(AtomicUsize::new(0)); + while let Some(msg) = rx.recv().await { + match msg { + ProviderMessage::GetRequestReceived(mut msg) => { + let connection_id = msg.connection_id; + let request_id = msg.request_id; + let res = requests.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |n| { + if n >= max_connections { + None + } else { + Some(n + 1) + } + }); + match res { + Ok(n) => { + println!("Accepting request {n}, id ({connection_id},{request_id})"); + msg.tx.send(Ok(())).await.ok(); + } + Err(_) => { + println!( + "Connection limit of {} exceeded, rejecting request", + max_connections + ); + msg.tx.send(Err(AbortReason::RateLimited)).await.ok(); + continue; + } + } + let requests = requests.clone(); + n0_future::task::spawn(async move { + // just drain the per request events + while let Ok(Some(_)) = msg.rx.recv().await {} + println!("Stopping request, id 
({connection_id},{request_id})"); + requests.fetch_sub(1, Ordering::SeqCst); + }); + } + _ => {} + } + } + }); + EventSender::new( + tx, + EventMask { + get: RequestMode::RequestLog, + ..EventMask::DEFAULT + }, + ) +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + setup_logging(); + let args = Args::parse(); + match args { + Args::Get { ticket } => { + let secret = get_or_generate_secret_key()?; + let endpoint = iroh::Endpoint::builder() + .secret_key(secret) + .discovery_n0() + .bind() + .await?; + let connection = endpoint + .connect(ticket.node_addr().clone(), iroh_blobs::ALPN) + .await?; + let (data, stats) = iroh_blobs::get::request::get_blob(connection, ticket.hash()) + .bytes_and_stats() + .await?; + println!("Downloaded {} bytes", data.len()); + println!("Stats: {:?}", stats); + } + Args::ByNodeId { + paths, + allowed_nodes, + secrets, + } => { + let mut allowed_nodes = allowed_nodes.into_iter().collect::>(); + if secrets > 0 { + println!("Generating {secrets} new secret keys for allowed nodes:"); + let mut rand = thread_rng(); + for _ in 0..secrets { + let secret = SecretKey::generate(&mut rand); + let public = secret.public(); + allowed_nodes.insert(public); + println!("IROH_SECRET={}", hex::encode(secret.to_bytes())); + } + } + let endpoint = iroh::Endpoint::builder().discovery_n0().bind().await?; + let store = MemStore::new(); + let mut hashes = HashMap::new(); + for path in paths { + let tag = store.add_path(&path).await?; + hashes.insert(path, tag.hash); + } + let _ = endpoint.home_relay().initialized().await; + let addr = endpoint.node_addr().initialized().await; + let events = limit_by_node_id(allowed_nodes.clone()); + let blobs = BlobsProtocol::new(&store, endpoint.clone(), Some(events)); + let router = iroh::protocol::Router::builder(endpoint) + .accept(iroh_blobs::ALPN, blobs) + .spawn(); + println!("Node id: {}\n", router.endpoint().node_id()); + for id in &allowed_nodes { + println!("Allowed node: {id}"); + } + println!(); + for (path, hash) in &hashes { + let ticket = BlobTicket::new(addr.clone(), *hash, iroh_blobs::BlobFormat::Raw); + println!("{}: {ticket}", path.display()); + } + tokio::signal::ctrl_c().await?; + router.shutdown().await?; + } + Args::ByHash { paths } => { + let endpoint = iroh::Endpoint::builder().discovery_n0().bind().await?; + let store = MemStore::new(); + let mut hashes = HashMap::new(); + let mut allowed_hashes = HashSet::new(); + for (i, path) in paths.into_iter().enumerate() { + let tag = store.add_path(&path).await?; + hashes.insert(path, tag.hash); + if i == 0 { + allowed_hashes.insert(tag.hash); + } + } + let _ = endpoint.home_relay().initialized().await; + let addr = endpoint.node_addr().initialized().await; + let events = limit_by_hash(allowed_hashes.clone()); + let blobs = BlobsProtocol::new(&store, endpoint.clone(), Some(events)); + let router = iroh::protocol::Router::builder(endpoint) + .accept(iroh_blobs::ALPN, blobs) + .spawn(); + for (i, (path, hash)) in hashes.iter().enumerate() { + let ticket = BlobTicket::new(addr.clone(), *hash, iroh_blobs::BlobFormat::Raw); + let permitted = if i == 0 { "" } else { "limited" }; + println!("{}: {ticket} ({permitted})", path.display()); + } + tokio::signal::ctrl_c().await?; + router.shutdown().await?; + } + Args::Throttle { paths, delay_ms } => { + let endpoint = iroh::Endpoint::builder().discovery_n0().bind().await?; + let store = MemStore::new(); + let mut hashes = HashMap::new(); + for path in paths { + let tag = store.add_path(&path).await?; + hashes.insert(path, tag.hash); + 
} + let _ = endpoint.home_relay().initialized().await; + let addr = endpoint.node_addr().initialized().await; + let events = throttle(delay_ms); + let blobs = BlobsProtocol::new(&store, endpoint.clone(), Some(events)); + let router = iroh::protocol::Router::builder(endpoint) + .accept(iroh_blobs::ALPN, blobs) + .spawn(); + for (path, hash) in hashes { + let ticket = BlobTicket::new(addr.clone(), hash, iroh_blobs::BlobFormat::Raw); + println!("{}: {ticket}", path.display()); + } + tokio::signal::ctrl_c().await?; + router.shutdown().await?; + } + Args::MaxConnections { + paths, + max_connections, + } => { + let endpoint = iroh::Endpoint::builder().discovery_n0().bind().await?; + let store = MemStore::new(); + let mut hashes = HashMap::new(); + for path in paths { + let tag = store.add_path(&path).await?; + hashes.insert(path, tag.hash); + } + let _ = endpoint.home_relay().initialized().await; + let addr = endpoint.node_addr().initialized().await; + let events = limit_max_connections(max_connections); + let blobs = BlobsProtocol::new(&store, endpoint.clone(), Some(events)); + let router = iroh::protocol::Router::builder(endpoint) + .accept(iroh_blobs::ALPN, blobs) + .spawn(); + for (path, hash) in hashes { + let ticket = BlobTicket::new(addr.clone(), hash, iroh_blobs::BlobFormat::Raw); + println!("{}: {ticket}", path.display()); + } + tokio::signal::ctrl_c().await?; + router.shutdown().await?; + } + } + Ok(()) +} From 4b87b6dcbc03540a3f36f93c0d8930386729b37b Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Tue, 2 Sep 2025 14:18:16 +0200 Subject: [PATCH 11/35] Add len to notify_payload_write --- src/provider.rs | 7 ++++--- src/provider/events.rs | 5 ++++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/provider.rs b/src/provider.rs index 1683daa57..883f97811 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -234,9 +234,10 @@ impl WriterContext { impl WriteProgress for WriterContext { async fn notify_payload_write(&mut self, _index: u64, offset: u64, len: usize) { - let end_offset = offset + len as u64; - self.payload_bytes_written += len as u64; - self.tracker.transfer_progress(end_offset).await.ok(); + let len = len as u64; + let end_offset = offset + len; + self.payload_bytes_written += len; + self.tracker.transfer_progress(len, end_offset).await.ok(); } fn log_other_write(&mut self, len: usize) { diff --git a/src/provider/events.rs b/src/provider/events.rs index b7fc58daa..35b641011 100644 --- a/src/provider/events.rs +++ b/src/provider/events.rs @@ -222,7 +222,7 @@ impl RequestTracker { } /// Transfer progress for the previously reported blob, end_offset is the new end offset in bytes. - pub async fn transfer_progress(&mut self, end_offset: u64) -> ClientResult { + pub async fn transfer_progress(&mut self, len: u64, end_offset: u64) -> ClientResult { if let RequestUpdates::Active(tx) = &mut self.updates { tx.try_send(RequestUpdate::Progress(TransferProgress { end_offset })) .await?; @@ -232,6 +232,7 @@ impl RequestTracker { .rpc(Throttle { connection_id: *connection_id, request_id: *request_id, + size: len as u64, }) .await??; } @@ -546,6 +547,8 @@ mod proto { pub connection_id: u64, /// The request id. There is a new id for each request. pub request_id: u64, + /// Size of the chunk to be throttled. This will usually be 16 KiB. 
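+ /// A throttle handler can use this to approximate a fixed transfer rate
+ /// instead of a fixed per-chunk delay. Minimal sketch, assuming a
+ /// hypothetical `rate_bytes_per_ms` chosen by the handler:
+ ///
+ /// ```ignore
+ /// let delay_ms = msg.size / rate_bytes_per_ms;
+ /// tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
+ /// msg.tx.send(Ok(())).await.ok();
+ /// ```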
+ pub size: u64, } #[derive(Debug, Serialize, Deserialize)] From f992a448a55aa84da6a9d5e1e5f9f203aad266d6 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Tue, 2 Sep 2025 14:18:54 +0200 Subject: [PATCH 12/35] clippy --- examples/limit.rs | 133 +++++++++++++++++++---------------------- src/provider/events.rs | 2 +- 2 files changed, 61 insertions(+), 74 deletions(-) diff --git a/examples/limit.rs b/examples/limit.rs index f23fad96c..09e1be132 100644 --- a/examples/limit.rs +++ b/examples/limit.rs @@ -64,19 +64,16 @@ fn limit_by_node_id(allowed_nodes: HashSet) -> EventSender { let (tx, mut rx) = tokio::sync::mpsc::channel(32); n0_future::task::spawn(async move { while let Some(msg) = rx.recv().await { - match msg { - ProviderMessage::ClientConnected(msg) => { - let node_id = msg.node_id; - let res = if allowed_nodes.contains(&node_id) { - println!("Client connected: {node_id}"); - Ok(()) - } else { - println!("Client rejected: {node_id}"); - Err(AbortReason::Permission) - }; - msg.tx.send(res).await.ok(); - } - _ => {} + if let ProviderMessage::ClientConnected(msg) = msg { + let node_id = msg.node_id; + let res = if allowed_nodes.contains(&node_id) { + println!("Client connected: {node_id}"); + Ok(()) + } else { + println!("Client rejected: {node_id}"); + Err(AbortReason::Permission) + }; + msg.tx.send(res).await.ok(); } } }); @@ -93,21 +90,18 @@ fn limit_by_hash(allowed_hashes: HashSet) -> EventSender { let (tx, mut rx) = tokio::sync::mpsc::channel(32); n0_future::task::spawn(async move { while let Some(msg) = rx.recv().await { - match msg { - ProviderMessage::GetRequestReceived(msg) => { - let res = if !msg.request.ranges.is_blob() { - println!("HashSeq request not allowed"); - Err(AbortReason::Permission) - } else if !allowed_hashes.contains(&msg.request.hash) { - println!("Request for hash {} not allowed", msg.request.hash); - Err(AbortReason::Permission) - } else { - println!("Request for hash {} allowed", msg.request.hash); - Ok(()) - }; - msg.tx.send(res).await.ok(); - } - _ => {} + if let ProviderMessage::GetRequestReceived(msg) = msg { + let res = if !msg.request.ranges.is_blob() { + println!("HashSeq request not allowed"); + Err(AbortReason::Permission) + } else if !allowed_hashes.contains(&msg.request.hash) { + println!("Request for hash {} not allowed", msg.request.hash); + Err(AbortReason::Permission) + } else { + println!("Request for hash {} allowed", msg.request.hash); + Ok(()) + }; + msg.tx.send(res).await.ok(); } } }); @@ -124,18 +118,15 @@ fn throttle(delay_ms: u64) -> EventSender { let (tx, mut rx) = tokio::sync::mpsc::channel(32); n0_future::task::spawn(async move { while let Some(msg) = rx.recv().await { - match msg { - ProviderMessage::Throttle(msg) => { - n0_future::task::spawn(async move { - println!( - "Throttling {} {}, {}ms", - msg.connection_id, msg.request_id, delay_ms - ); - tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; - msg.tx.send(Ok(())).await.ok(); - }); - } - _ => {} + if let ProviderMessage::Throttle(msg) = msg { + n0_future::task::spawn(async move { + println!( + "Throttling {} {}, {}ms", + msg.connection_id, msg.request_id, delay_ms + ); + tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; + msg.tx.send(Ok(())).await.ok(); + }); } } }); @@ -153,40 +144,36 @@ fn limit_max_connections(max_connections: usize) -> EventSender { n0_future::task::spawn(async move { let requests = Arc::new(AtomicUsize::new(0)); while let Some(msg) = rx.recv().await { - match msg { - ProviderMessage::GetRequestReceived(mut msg) => { 
- let connection_id = msg.connection_id; - let request_id = msg.request_id; - let res = requests.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |n| { - if n >= max_connections { - None - } else { - Some(n + 1) - } - }); - match res { - Ok(n) => { - println!("Accepting request {n}, id ({connection_id},{request_id})"); - msg.tx.send(Ok(())).await.ok(); - } - Err(_) => { - println!( - "Connection limit of {} exceeded, rejecting request", - max_connections - ); - msg.tx.send(Err(AbortReason::RateLimited)).await.ok(); - continue; - } + if let ProviderMessage::GetRequestReceived(mut msg) = msg { + let connection_id = msg.connection_id; + let request_id = msg.request_id; + let res = requests.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |n| { + if n >= max_connections { + None + } else { + Some(n + 1) + } + }); + match res { + Ok(n) => { + println!("Accepting request {n}, id ({connection_id},{request_id})"); + msg.tx.send(Ok(())).await.ok(); + } + Err(_) => { + println!( + "Connection limit of {max_connections} exceeded, rejecting request" + ); + msg.tx.send(Err(AbortReason::RateLimited)).await.ok(); + continue; } - let requests = requests.clone(); - n0_future::task::spawn(async move { - // just drain the per request events - while let Ok(Some(_)) = msg.rx.recv().await {} - println!("Stopping request, id ({connection_id},{request_id})"); - requests.fetch_sub(1, Ordering::SeqCst); - }); } - _ => {} + let requests = requests.clone(); + n0_future::task::spawn(async move { + // just drain the per request events + while let Ok(Some(_)) = msg.rx.recv().await {} + println!("Stopping request, id ({connection_id},{request_id})"); + requests.fetch_sub(1, Ordering::SeqCst); + }); } } }); @@ -218,7 +205,7 @@ async fn main() -> anyhow::Result<()> { .bytes_and_stats() .await?; println!("Downloaded {} bytes", data.len()); - println!("Stats: {:?}", stats); + println!("Stats: {stats:?}"); } Args::ByNodeId { paths, diff --git a/src/provider/events.rs b/src/provider/events.rs index 35b641011..5e5972167 100644 --- a/src/provider/events.rs +++ b/src/provider/events.rs @@ -232,7 +232,7 @@ impl RequestTracker { .rpc(Throttle { connection_id: *connection_id, request_id: *request_id, - size: len as u64, + size: len, }) .await??; } From 4bddf77939f672d841f09fc2c522176c1e94a775 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Tue, 2 Sep 2025 14:27:31 +0200 Subject: [PATCH 13/35] nicer connection counter --- examples/limit.rs | 44 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 35 insertions(+), 9 deletions(-) diff --git a/examples/limit.rs b/examples/limit.rs index 09e1be132..e7b86a8ca 100644 --- a/examples/limit.rs +++ b/examples/limit.rs @@ -29,6 +29,7 @@ use crate::common::get_or_generate_secret_key; #[derive(Debug, Parser)] #[command(version, about)] pub enum Args { + /// Limit requests by node id ByNodeId { /// Path for files to add paths: Vec, @@ -38,16 +39,19 @@ pub enum Args { #[clap(long, default_value_t = 1)] secrets: usize, }, + /// Limit requests by hash, only first hash is allowed ByHash { /// Path for files to add paths: Vec, }, + /// Throttle requests Throttle { /// Path for files to add paths: Vec, #[clap(long, default_value = "100")] delay_ms: u64, }, + /// Limit maximum number of connections. 
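+ ///
+ /// Note that this counts in-flight get requests rather than raw QUIC
+ /// connections: the counter goes up when a get request is accepted and
+ /// down once its per-request event stream ends. Assumed invocation,
+ /// relying on clap's kebab-case subcommand names:
+ ///
+ /// `cargo run --example limit -- max-connections foo.txt --max-connections 2`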
MaxConnections { /// Path for files to add paths: Vec, @@ -140,20 +144,39 @@ fn throttle(delay_ms: u64) -> EventSender { } fn limit_max_connections(max_connections: usize) -> EventSender { + #[derive(Default, Debug, Clone)] + struct ConnectionCounter(Arc<(AtomicUsize, usize)>); + + impl ConnectionCounter { + fn new(max: usize) -> Self { + Self(Arc::new((Default::default(), max))) + } + + fn inc(&self) -> Result { + let (c, max) = &*self.0; + c.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |n| { + if n >= *max { + None + } else { + Some(n + 1) + } + }) + } + + fn dec(&self) { + let (c, _) = &*self.0; + c.fetch_sub(1, Ordering::SeqCst); + } + } + let (tx, mut rx) = tokio::sync::mpsc::channel(32); n0_future::task::spawn(async move { - let requests = Arc::new(AtomicUsize::new(0)); + let requests = ConnectionCounter::new(max_connections); while let Some(msg) = rx.recv().await { if let ProviderMessage::GetRequestReceived(mut msg) = msg { let connection_id = msg.connection_id; let request_id = msg.request_id; - let res = requests.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |n| { - if n >= max_connections { - None - } else { - Some(n + 1) - } - }); + let res = requests.inc(); match res { Ok(n) => { println!("Accepting request {n}, id ({connection_id},{request_id})"); @@ -170,9 +193,12 @@ fn limit_max_connections(max_connections: usize) -> EventSender { let requests = requests.clone(); n0_future::task::spawn(async move { // just drain the per request events + // + // Note that we have requested updates for the request, now we also need to process them + // otherwise the request will be aborted! while let Ok(Some(_)) = msg.rx.recv().await {} println!("Stopping request, id ({connection_id},{request_id})"); - requests.fetch_sub(1, Ordering::SeqCst); + requests.dec(); }); } } From 33333a9afc5659aa1c341474ee40a9952f0e49da Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Tue, 2 Sep 2025 14:38:36 +0200 Subject: [PATCH 14/35] Add docs for the limit example. --- examples/limit.rs | 35 ++++++++++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/examples/limit.rs b/examples/limit.rs index e7b86a8ca..fff910da3 100644 --- a/examples/limit.rs +++ b/examples/limit.rs @@ -1,5 +1,13 @@ /// Example how to limit blob requests by hash and node id, and to add -/// restrictions on limited content. +/// throttling or limiting the maximum number of connections. +/// +/// Limiting is done via a fn that returns an EventSender and internally +/// makes liberal use of spawn to spawn background tasks. +/// +/// This is fine, since the tasks will terminate as soon as the [BlobsProtocol] +/// instance holding the [EventSender] will be dropped. But for production +/// grade code you might nevertheless put the tasks into a [tokio::task::JoinSet] or +/// [n0_future::FuturesUnordered]. mod common; use std::{ collections::{HashMap, HashSet}, @@ -31,33 +39,37 @@ use crate::common::get_or_generate_secret_key; pub enum Args { /// Limit requests by node id ByNodeId { - /// Path for files to add + /// Path for files to add. paths: Vec, #[clap(long("allow"))] /// Nodes that are allowed to download content. allowed_nodes: Vec, + /// Number of secrets to generate for allowed node ids. #[clap(long, default_value_t = 1)] secrets: usize, }, /// Limit requests by hash, only first hash is allowed ByHash { - /// Path for files to add + /// Path for files to add. paths: Vec, }, /// Throttle requests Throttle { - /// Path for files to add + /// Path for files to add. 
paths: Vec, + /// Delay in milliseconds after sending a chunk group of 16 KiB. #[clap(long, default_value = "100")] delay_ms: u64, }, /// Limit maximum number of connections. MaxConnections { - /// Path for files to add + /// Path for files to add. paths: Vec, + /// Maximum number of concurrent get requests. #[clap(long, default_value = "1")] max_connections: usize, }, + /// Get a blob. Just for completeness sake. Get { /// Ticket for the blob to download ticket: BlobTicket, @@ -84,6 +96,8 @@ fn limit_by_node_id(allowed_nodes: HashSet) -> EventSender { EventSender::new( tx, EventMask { + // We want a request for each incoming connection so we can accept + // or reject them. We don't need any other events. connected: ConnectMode::Request, ..EventMask::DEFAULT }, @@ -112,6 +126,9 @@ fn limit_by_hash(allowed_hashes: HashSet) -> EventSender { EventSender::new( tx, EventMask { + // We want to get a request for each get request that we can answer + // with OK or not OK depending on the hash. We do not want detailed + // events once it has been decided to handle a request. get: RequestMode::Request, ..EventMask::DEFAULT }, @@ -128,6 +145,8 @@ fn throttle(delay_ms: u64) -> EventSender { "Throttling {} {}, {}ms", msg.connection_id, msg.request_id, delay_ms ); + // we could compute the delay from the size of the data to have a fixed rate. + // but the size is almost always 16 KiB (16 chunks). tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; msg.tx.send(Ok(())).await.ok(); }); @@ -137,6 +156,8 @@ fn throttle(delay_ms: u64) -> EventSender { EventSender::new( tx, EventMask { + // We want to get requests for each sent user data blob, so we can add a delay. + // Other than that, we don't need any events. throttle: ThrottleMode::Throttle, ..EventMask::DEFAULT }, @@ -206,6 +227,10 @@ fn limit_max_connections(max_connections: usize) -> EventSender { EventSender::new( tx, EventMask { + // For each get request, we want to get a request so we can decide + // based on the current connection count if we want to accept or reject. + // We also want detailed logging of events for the get request, so we can + // detect when the request is finished one way or another. get: RequestMode::RequestLog, ..EventMask::DEFAULT }, From 9a62a581b1cc903f5d8c5dcf8fd947a3f6cbd980 Mon Sep 17 00:00:00 2001 From: Frando Date: Wed, 3 Sep 2025 13:43:44 +0200 Subject: [PATCH 15/35] refactor: make limits example more DRY --- examples/limit.rs | 184 +++++++++++++++++++---------------------- src/provider/events.rs | 8 ++ 2 files changed, 94 insertions(+), 98 deletions(-) diff --git a/examples/limit.rs b/examples/limit.rs index fff910da3..2be92358a 100644 --- a/examples/limit.rs +++ b/examples/limit.rs @@ -18,9 +18,10 @@ use std::{ }, }; +use anyhow::Result; use clap::Parser; use common::setup_logging; -use iroh::{NodeId, SecretKey, Watcher}; +use iroh::{protocol::Router, NodeAddr, NodeId, SecretKey, Watcher}; use iroh_blobs::{ provider::events::{ AbortReason, ConnectMode, EventMask, EventSender, ProviderMessage, RequestMode, @@ -28,7 +29,7 @@ use iroh_blobs::{ }, store::mem::MemStore, ticket::BlobTicket, - BlobsProtocol, Hash, + BlobFormat, BlobsProtocol, Hash, }; use rand::thread_rng; @@ -77,7 +78,13 @@ pub enum Args { } fn limit_by_node_id(allowed_nodes: HashSet) -> EventSender { - let (tx, mut rx) = tokio::sync::mpsc::channel(32); + let mask = EventMask { + // We want a request for each incoming connection so we can accept + // or reject them. We don't need any other events. 
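+ // Everything not listed here keeps its value from EventMask::DEFAULT,
+ // which leaves those event kinds unreported.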
+ connected: ConnectMode::Request, + ..EventMask::DEFAULT + }; + let (tx, mut rx) = EventSender::channel(32, mask); n0_future::task::spawn(async move { while let Some(msg) = rx.recv().await { if let ProviderMessage::ClientConnected(msg) = msg { @@ -93,19 +100,18 @@ fn limit_by_node_id(allowed_nodes: HashSet) -> EventSender { } } }); - EventSender::new( - tx, - EventMask { - // We want a request for each incoming connection so we can accept - // or reject them. We don't need any other events. - connected: ConnectMode::Request, - ..EventMask::DEFAULT - }, - ) + tx } fn limit_by_hash(allowed_hashes: HashSet) -> EventSender { - let (tx, mut rx) = tokio::sync::mpsc::channel(32); + let mask = EventMask { + // We want to get a request for each get request that we can answer + // with OK or not OK depending on the hash. We do not want detailed + // events once it has been decided to handle a request. + get: RequestMode::Request, + ..EventMask::DEFAULT + }; + let (tx, mut rx) = EventSender::channel(32, mask); n0_future::task::spawn(async move { while let Some(msg) = rx.recv().await { if let ProviderMessage::GetRequestReceived(msg) = msg { @@ -123,20 +129,17 @@ fn limit_by_hash(allowed_hashes: HashSet) -> EventSender { } } }); - EventSender::new( - tx, - EventMask { - // We want to get a request for each get request that we can answer - // with OK or not OK depending on the hash. We do not want detailed - // events once it has been decided to handle a request. - get: RequestMode::Request, - ..EventMask::DEFAULT - }, - ) + tx } fn throttle(delay_ms: u64) -> EventSender { - let (tx, mut rx) = tokio::sync::mpsc::channel(32); + let mask = EventMask { + // We want to get requests for each sent user data blob, so we can add a delay. + // Other than that, we don't need any events. + throttle: ThrottleMode::Throttle, + ..EventMask::DEFAULT + }; + let (tx, mut rx) = EventSender::channel(32, mask); n0_future::task::spawn(async move { while let Some(msg) = rx.recv().await { if let ProviderMessage::Throttle(msg) = msg { @@ -153,15 +156,7 @@ fn throttle(delay_ms: u64) -> EventSender { } } }); - EventSender::new( - tx, - EventMask { - // We want to get requests for each sent user data blob, so we can add a delay. - // Other than that, we don't need any events. - throttle: ThrottleMode::Throttle, - ..EventMask::DEFAULT - }, - ) + tx } fn limit_max_connections(max_connections: usize) -> EventSender { @@ -190,7 +185,15 @@ fn limit_max_connections(max_connections: usize) -> EventSender { } } - let (tx, mut rx) = tokio::sync::mpsc::channel(32); + let mask = EventMask { + // For each get request, we want to get a request so we can decide + // based on the current connection count if we want to accept or reject. + // We also want detailed logging of events for the get request, so we can + // detect when the request is finished one way or another. + get: RequestMode::RequestLog, + ..EventMask::DEFAULT + }; + let (tx, mut rx) = EventSender::channel(32, mask); n0_future::task::spawn(async move { let requests = ConnectionCounter::new(max_connections); while let Some(msg) = rx.recv().await { @@ -224,21 +227,11 @@ fn limit_max_connections(max_connections: usize) -> EventSender { } } }); - EventSender::new( - tx, - EventMask { - // For each get request, we want to get a request so we can decide - // based on the current connection count if we want to accept or reject. - // We also want detailed logging of events for the get request, so we can - // detect when the request is finished one way or another. 
- get: RequestMode::RequestLog, - ..EventMask::DEFAULT - }, - ) + tx } #[tokio::main] -async fn main() -> anyhow::Result<()> { +async fn main() -> Result<()> { setup_logging(); let args = Args::parse(); match args { @@ -274,35 +267,28 @@ async fn main() -> anyhow::Result<()> { println!("IROH_SECRET={}", hex::encode(secret.to_bytes())); } } - let endpoint = iroh::Endpoint::builder().discovery_n0().bind().await?; + let store = MemStore::new(); - let mut hashes = HashMap::new(); - for path in paths { - let tag = store.add_path(&path).await?; - hashes.insert(path, tag.hash); - } - let _ = endpoint.home_relay().initialized().await; - let addr = endpoint.node_addr().initialized().await; + let hashes = add_paths(&store, paths).await?; let events = limit_by_node_id(allowed_nodes.clone()); - let blobs = BlobsProtocol::new(&store, endpoint.clone(), Some(events)); - let router = iroh::protocol::Router::builder(endpoint) - .accept(iroh_blobs::ALPN, blobs) - .spawn(); + let (router, addr) = setup(MemStore::new(), events).await?; + + for (path, hash) in hashes { + let ticket = BlobTicket::new(addr.clone(), hash, BlobFormat::Raw); + println!("{}: {ticket}", path.display()); + } + println!(); println!("Node id: {}\n", router.endpoint().node_id()); for id in &allowed_nodes { println!("Allowed node: {id}"); } - println!(); - for (path, hash) in &hashes { - let ticket = BlobTicket::new(addr.clone(), *hash, iroh_blobs::BlobFormat::Raw); - println!("{}: {ticket}", path.display()); - } + tokio::signal::ctrl_c().await?; router.shutdown().await?; } Args::ByHash { paths } => { - let endpoint = iroh::Endpoint::builder().discovery_n0().bind().await?; let store = MemStore::new(); + let mut hashes = HashMap::new(); let mut allowed_hashes = HashSet::new(); for (i, path) in paths.into_iter().enumerate() { @@ -312,15 +298,12 @@ async fn main() -> anyhow::Result<()> { allowed_hashes.insert(tag.hash); } } - let _ = endpoint.home_relay().initialized().await; - let addr = endpoint.node_addr().initialized().await; - let events = limit_by_hash(allowed_hashes.clone()); - let blobs = BlobsProtocol::new(&store, endpoint.clone(), Some(events)); - let router = iroh::protocol::Router::builder(endpoint) - .accept(iroh_blobs::ALPN, blobs) - .spawn(); + + let events = limit_by_hash(allowed_hashes); + let (router, addr) = setup(MemStore::new(), events).await?; + for (i, (path, hash)) in hashes.iter().enumerate() { - let ticket = BlobTicket::new(addr.clone(), *hash, iroh_blobs::BlobFormat::Raw); + let ticket = BlobTicket::new(addr.clone(), *hash, BlobFormat::Raw); let permitted = if i == 0 { "" } else { "limited" }; println!("{}: {ticket} ({permitted})", path.display()); } @@ -328,22 +311,12 @@ async fn main() -> anyhow::Result<()> { router.shutdown().await?; } Args::Throttle { paths, delay_ms } => { - let endpoint = iroh::Endpoint::builder().discovery_n0().bind().await?; let store = MemStore::new(); - let mut hashes = HashMap::new(); - for path in paths { - let tag = store.add_path(&path).await?; - hashes.insert(path, tag.hash); - } - let _ = endpoint.home_relay().initialized().await; - let addr = endpoint.node_addr().initialized().await; + let hashes = add_paths(&store, paths).await?; let events = throttle(delay_ms); - let blobs = BlobsProtocol::new(&store, endpoint.clone(), Some(events)); - let router = iroh::protocol::Router::builder(endpoint) - .accept(iroh_blobs::ALPN, blobs) - .spawn(); + let (router, addr) = setup(MemStore::new(), events).await?; for (path, hash) in hashes { - let ticket = BlobTicket::new(addr.clone(), hash, 
iroh_blobs::BlobFormat::Raw); + let ticket = BlobTicket::new(addr.clone(), hash, BlobFormat::Raw); println!("{}: {ticket}", path.display()); } tokio::signal::ctrl_c().await?; @@ -353,22 +326,12 @@ async fn main() -> anyhow::Result<()> { paths, max_connections, } => { - let endpoint = iroh::Endpoint::builder().discovery_n0().bind().await?; let store = MemStore::new(); - let mut hashes = HashMap::new(); - for path in paths { - let tag = store.add_path(&path).await?; - hashes.insert(path, tag.hash); - } - let _ = endpoint.home_relay().initialized().await; - let addr = endpoint.node_addr().initialized().await; + let hashes = add_paths(&store, paths).await?; let events = limit_max_connections(max_connections); - let blobs = BlobsProtocol::new(&store, endpoint.clone(), Some(events)); - let router = iroh::protocol::Router::builder(endpoint) - .accept(iroh_blobs::ALPN, blobs) - .spawn(); + let (router, addr) = setup(MemStore::new(), events).await?; for (path, hash) in hashes { - let ticket = BlobTicket::new(addr.clone(), hash, iroh_blobs::BlobFormat::Raw); + let ticket = BlobTicket::new(addr.clone(), hash, BlobFormat::Raw); println!("{}: {ticket}", path.display()); } tokio::signal::ctrl_c().await?; @@ -377,3 +340,28 @@ async fn main() -> anyhow::Result<()> { } Ok(()) } + +async fn add_paths(store: &MemStore, paths: Vec) -> Result> { + let mut hashes = HashMap::new(); + for path in paths { + let tag = store.add_path(&path).await?; + hashes.insert(path, tag.hash); + } + Ok(hashes) +} + +async fn setup(store: MemStore, events: EventSender) -> Result<(Router, NodeAddr)> { + let secret = get_or_generate_secret_key()?; + let endpoint = iroh::Endpoint::builder() + .discovery_n0() + .secret_key(secret) + .bind() + .await?; + let _ = endpoint.home_relay().initialized().await; + let addr = endpoint.node_addr().initialized().await; + let blobs = BlobsProtocol::new(&store, endpoint.clone(), Some(events)); + let router = Router::builder(endpoint) + .accept(iroh_blobs::ALPN, blobs) + .spawn(); + Ok((router, addr)) +} diff --git a/src/provider/events.rs b/src/provider/events.rs index 5e5972167..f2bddb23c 100644 --- a/src/provider/events.rs +++ b/src/provider/events.rs @@ -276,6 +276,14 @@ impl EventSender { } } + pub fn channel( + capacity: usize, + mask: EventMask, + ) -> (Self, tokio::sync::mpsc::Receiver) { + let (tx, rx) = tokio::sync::mpsc::channel(capacity); + (Self::new(tx, mask), rx) + } + /// Log request events at trace level. 
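/// Illustrative use, assuming the read-only mask that the tests and the
/// random_store example use:
///
/// ```ignore
/// let events = events.tracing(EventMask::ALL_READONLY);
/// ```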
pub fn tracing(&self, mask: EventMask) -> Self { use tracing::trace; From 071db5e0a6c69edcac2a3d42de7103cfabcd866c Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Wed, 3 Sep 2025 15:05:11 +0200 Subject: [PATCH 16/35] Make sure to send a proper reset code when resetting a connection so the other side can know if reconnecting is OK --- examples/limit.rs | 12 +++--- src/protocol.rs | 7 +++ src/provider.rs | 97 +++++++++++++++++++++++++++--------------- src/provider/events.rs | 47 ++++++++++++++------ 4 files changed, 110 insertions(+), 53 deletions(-) diff --git a/examples/limit.rs b/examples/limit.rs index 2be92358a..830574fcc 100644 --- a/examples/limit.rs +++ b/examples/limit.rs @@ -234,14 +234,14 @@ fn limit_max_connections(max_connections: usize) -> EventSender { async fn main() -> Result<()> { setup_logging(); let args = Args::parse(); + let secret = get_or_generate_secret_key()?; + let endpoint = iroh::Endpoint::builder() + .secret_key(secret) + .discovery_n0() + .bind() + .await?; match args { Args::Get { ticket } => { - let secret = get_or_generate_secret_key()?; - let endpoint = iroh::Endpoint::builder() - .secret_key(secret) - .discovery_n0() - .bind() - .await?; let connection = endpoint .connect(ticket.node_addr().clone(), iroh_blobs::ALPN) .await?; diff --git a/src/protocol.rs b/src/protocol.rs index 05ee00678..ce10865a5 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -397,6 +397,13 @@ use crate::{api::blobs::Bitfield, provider::RecvStreamExt, BlobFormat, Hash, Has /// Maximum message size is limited to 100MiB for now. pub const MAX_MESSAGE_SIZE: usize = 1024 * 1024; +/// Error code for a permission error +pub const ERR_PERMISSION: VarInt = VarInt::from_u32(1u32); +/// Error code for when a request is aborted due to a rate limit +pub const ERR_LIMIT: VarInt = VarInt::from_u32(2u32); +/// Error code for when a request is aborted due to internal error +pub const ERR_INTERNAL: VarInt = VarInt::from_u32(3u32); + /// The ALPN used with quic for the iroh blobs protocol. pub const ALPN: &[u8] = b"/iroh-bytes/4"; diff --git a/src/provider.rs b/src/provider.rs index 883f97811..49b57e13a 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -20,13 +20,12 @@ use tracing::{debug, debug_span, warn, Instrument}; use crate::{ api::{ - self, blobs::{Bitfield, WriteProgress}, - Store, + ExportBaoResult, Store, }, hashseq::HashSeq, protocol::{GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, Request}, - provider::events::{ClientConnected, ClientError, ConnectionClosed, RequestTracker}, + provider::events::{ClientConnected, ConnectionClosed, RequestTracker}, Hash, }; pub mod events; @@ -94,7 +93,7 @@ impl StreamPair { } /// We are done with reading. 
Return a ProgressWriter that contains the read stats and connection id - pub async fn into_writer( + async fn into_writer( mut self, tracker: RequestTracker, ) -> Result { @@ -118,7 +117,7 @@ impl StreamPair { )) } - pub async fn into_reader( + async fn into_reader( mut self, tracker: RequestTracker, ) -> Result { @@ -141,39 +140,71 @@ impl StreamPair { } pub async fn get_request( - &self, + mut self, f: impl FnOnce() -> GetRequest, - ) -> Result { - self.events + ) -> anyhow::Result { + let res = self + .events .request(f, self.connection_id, self.request_id) - .await + .await; + match res { + Err(e) => { + self.writer.reset(e.code()).ok(); + Err(e.into()) + } + Ok(tracker) => Ok(self.into_writer(tracker).await?), + } } pub async fn get_many_request( - &self, + mut self, f: impl FnOnce() -> GetManyRequest, - ) -> Result { - self.events + ) -> anyhow::Result { + let res = self + .events .request(f, self.connection_id, self.request_id) - .await + .await; + match res { + Err(e) => { + self.writer.reset(e.code()).ok(); + Err(e.into()) + } + Ok(tracker) => Ok(self.into_writer(tracker).await?), + } } pub async fn push_request( - &self, + mut self, f: impl FnOnce() -> PushRequest, - ) -> Result { - self.events + ) -> anyhow::Result { + let res = self + .events .request(f, self.connection_id, self.request_id) - .await + .await; + match res { + Err(e) => { + self.writer.reset(e.code()).ok(); + Err(e.into()) + } + Ok(tracker) => Ok(self.into_reader(tracker).await?), + } } pub async fn observe_request( - &self, + mut self, f: impl FnOnce() -> ObserveRequest, - ) -> Result { - self.events + ) -> anyhow::Result { + let res = self + .events .request(f, self.connection_id, self.request_id) - .await + .await; + match res { + Err(e) => { + self.writer.reset(e.code()).ok(); + Err(e.into()) + } + Ok(tracker) => Ok(self.into_writer(tracker).await?), + } } fn stats(&self) -> TransferStats { @@ -299,7 +330,8 @@ pub async fn handle_connection( }) .await { - debug!("client not authorized to connect: {cause}"); + connection.close(cause.code(), cause.reason()); + debug!("closing connection: {cause}"); return; } while let Ok(context) = StreamPair::accept(&connection, &progress).await { @@ -323,17 +355,16 @@ async fn handle_stream(store: Store, mut context: StreamPair) -> anyhow::Result< match request { Request::Get(request) => { - let tracker = context.get_request(|| request.clone()).await?; - let mut writer = context.into_writer(tracker).await?; - if handle_get(store, request, &mut writer).await.is_ok() { + let mut writer = context.get_request(|| request.clone()).await?; + let res = handle_get(store, request, &mut writer).await; + if res.is_ok() { writer.transfer_completed().await; } else { writer.transfer_aborted().await; } } Request::GetMany(request) => { - let tracker = context.get_many_request(|| request.clone()).await?; - let mut writer = context.into_writer(tracker).await?; + let mut writer = context.get_many_request(|| request.clone()).await?; if handle_get_many(store, request, &mut writer).await.is_ok() { writer.transfer_completed().await; } else { @@ -341,8 +372,7 @@ async fn handle_stream(store: Store, mut context: StreamPair) -> anyhow::Result< } } Request::Observe(request) => { - let tracker = context.observe_request(|| request.clone()).await?; - let mut writer = context.into_writer(tracker).await?; + let mut writer = context.observe_request(|| request.clone()).await?; if handle_observe(store, request, &mut writer).await.is_ok() { writer.transfer_completed().await; } else { @@ -350,8 +380,7 @@ async fn 
handle_stream(store: Store, mut context: StreamPair) -> anyhow::Result< } } Request::Push(request) => { - let tracker = context.push_request(|| request.clone()).await?; - let mut reader = context.into_reader(tracker).await?; + let mut reader = context.push_request(|| request.clone()).await?; if handle_push(store, request, &mut reader).await.is_ok() { reader.transfer_completed().await; } else { @@ -464,11 +493,11 @@ pub(crate) async fn send_blob( hash: Hash, ranges: ChunkRanges, writer: &mut ProgressWriter, -) -> api::Result<()> { - Ok(store +) -> ExportBaoResult<()> { + store .export_bao(hash, ranges) .write_quinn_with_progress(&mut writer.inner, &mut writer.context, &hash, index) - .await?) + .await } /// Handle a single push request. diff --git a/src/provider/events.rs b/src/provider/events.rs index f2bddb23c..5a922300a 100644 --- a/src/provider/events.rs +++ b/src/provider/events.rs @@ -8,7 +8,10 @@ use serde::{Deserialize, Serialize}; use snafu::Snafu; use crate::{ - protocol::{GetManyRequest, GetRequest, ObserveRequest, PushRequest}, + protocol::{ + GetManyRequest, GetRequest, ObserveRequest, PushRequest, ERR_INTERNAL, ERR_LIMIT, + ERR_PERMISSION, + }, provider::{events::irpc_ext::IrpcClientExt, TransferStats}, Hash, }; @@ -82,6 +85,24 @@ pub enum ClientError { }, } +impl ClientError { + pub fn code(&self) -> quinn::VarInt { + match self { + ClientError::RateLimited => ERR_LIMIT, + ClientError::Permission => ERR_PERMISSION, + ClientError::Irpc { .. } => ERR_INTERNAL, + } + } + + pub fn reason(&self) -> &'static [u8] { + match self { + ClientError::RateLimited => b"limit", + ClientError::Permission => b"permission", + ClientError::Irpc { .. } => b"internal", + } + } +} + impl From for ClientError { fn from(value: AbortReason) -> Self { match value { @@ -211,11 +232,14 @@ impl RequestTracker { /// Transfer for index `index` started, size `size` pub async fn transfer_started(&self, index: u64, hash: &Hash, size: u64) -> irpc::Result<()> { if let RequestUpdates::Active(tx) = &self.updates { - tx.send(RequestUpdate::Started(TransferStarted { - index, - hash: *hash, - size, - })) + tx.send( + TransferStarted { + index, + hash: *hash, + size, + } + .into(), + ) .await?; } Ok(()) @@ -224,8 +248,7 @@ impl RequestTracker { /// Transfer progress for the previously reported blob, end_offset is the new end offset in bytes. pub async fn transfer_progress(&mut self, len: u64, end_offset: u64) -> ClientResult { if let RequestUpdates::Active(tx) = &mut self.updates { - tx.try_send(RequestUpdate::Progress(TransferProgress { end_offset })) - .await?; + tx.try_send(TransferProgress { end_offset }.into()).await?; } if let Some((throttle, connection_id, request_id)) = &self.throttle { throttle @@ -242,8 +265,7 @@ impl RequestTracker { /// Transfer completed for the previously reported blob. pub async fn transfer_completed(&self, f: impl Fn() -> Box) -> irpc::Result<()> { if let RequestUpdates::Active(tx) = &self.updates { - tx.send(RequestUpdate::Completed(TransferCompleted { stats: f() })) - .await?; + tx.send(TransferCompleted { stats: f() }.into()).await?; } Ok(()) } @@ -251,8 +273,7 @@ impl RequestTracker { /// Transfer aborted for the previously reported blob. 
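/// As with `transfer_completed`, the stats closure `f` is only evaluated
/// when an active update channel exists for this request.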
pub async fn transfer_aborted(&self, f: impl Fn() -> Box) -> irpc::Result<()> { if let RequestUpdates::Active(tx) = &self.updates { - tx.send(RequestUpdate::Aborted(TransferAborted { stats: f() })) - .await?; + tx.send(TransferAborted { stats: f() }.into()).await?; } Ok(()) } @@ -583,7 +604,7 @@ mod proto { } /// Stream of updates for a single request - #[derive(Debug, Serialize, Deserialize)] + #[derive(Debug, Serialize, Deserialize, derive_more::From)] pub enum RequestUpdate { /// Start of transfer for a blob, mandatory event Started(TransferStarted), From 2d72de0ac64e5e4f00dbb5c8cb712e201693b51b Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Wed, 3 Sep 2025 15:15:27 +0200 Subject: [PATCH 17/35] deny --- Cargo.lock | 48 +++++++++++++----------------------------------- 1 file changed, 13 insertions(+), 35 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4068354f7..988d7955a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2094,11 +2094,11 @@ checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" [[package]] name = "matchers" -version = "0.1.0" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" dependencies = [ - "regex-automata 0.1.10", + "regex-automata", ] [[package]] @@ -2385,12 +2385,11 @@ dependencies = [ [[package]] name = "nu-ansi-term" -version = "0.46.0" +version = "0.50.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +checksum = "d4a28e057d01f97e61255210fcff094d74ed0466038633e95017f5beb68e4399" dependencies = [ - "overload", - "winapi", + "windows-sys 0.52.0", ] [[package]] @@ -2467,12 +2466,6 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" -[[package]] -name = "overload" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" - [[package]] name = "parking" version = "2.2.1" @@ -2907,7 +2900,7 @@ dependencies = [ "rand 0.9.2", "rand_chacha 0.9.0", "rand_xorshift", - "regex-syntax 0.8.5", + "regex-syntax", "rusty-fork", "tempfile", "unarray", @@ -3151,17 +3144,8 @@ checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.9", - "regex-syntax 0.8.5", -] - -[[package]] -name = "regex-automata" -version = "0.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" -dependencies = [ - "regex-syntax 0.6.29", + "regex-automata", + "regex-syntax", ] [[package]] @@ -3172,7 +3156,7 @@ checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", - "regex-syntax 0.8.5", + "regex-syntax", ] [[package]] @@ -3181,12 +3165,6 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a" -[[package]] -name = "regex-syntax" -version = "0.6.29" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" - [[package]] name = "regex-syntax" 
version = "0.8.5" @@ -4245,14 +4223,14 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.19" +version = "0.3.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" +checksum = "2054a14f5307d601f88daf0553e1cbf472acc4f2c51afab632431cdcd72124d5" dependencies = [ "matchers", "nu-ansi-term", "once_cell", - "regex", + "regex-automata", "sharded-slab", "smallvec", "thread_local", From 2dac46c3bf6ff0fb3b0d1d5648043826afbdaa64 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Wed, 3 Sep 2025 15:25:46 +0200 Subject: [PATCH 18/35] Use async syntax for implementing ProtocolHandler --- src/net_protocol.rs | 25 ++++++++----------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/src/net_protocol.rs b/src/net_protocol.rs index 269ef0e14..47cda5344 100644 --- a/src/net_protocol.rs +++ b/src/net_protocol.rs @@ -36,7 +36,7 @@ //! # } //! ``` -use std::{fmt::Debug, future::Future, ops::Deref, sync::Arc}; +use std::{fmt::Debug, ops::Deref, sync::Arc}; use iroh::{ endpoint::Connection, @@ -100,25 +100,16 @@ impl BlobsProtocol { } impl ProtocolHandler for BlobsProtocol { - fn accept( - &self, - conn: Connection, - ) -> impl Future> + Send { + async fn accept(&self, conn: Connection) -> std::result::Result<(), AcceptError> { let store = self.store().clone(); let events = self.inner.events.clone(); - - Box::pin(async move { - crate::provider::handle_connection(conn, store, events).await; - Ok(()) - }) + crate::provider::handle_connection(conn, store, events).await; + Ok(()) } - fn shutdown(&self) -> impl Future + Send { - let store = self.store().clone(); - Box::pin(async move { - if let Err(cause) = store.shutdown().await { - error!("error shutting down store: {:?}", cause); - } - }) + async fn shutdown(&self) { + if let Err(cause) = self.store().shutdown().await { + error!("error shutting down store: {:?}", cause); + } } } From a67d7875e3a9107ba210d8e89092fc632207dbc2 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Wed, 3 Sep 2025 15:57:39 +0200 Subject: [PATCH 19/35] Use irpc::channel::SendError as default sink error. --- src/api.rs | 11 ++++++++- src/api/blobs.rs | 7 ++++-- src/api/downloader.rs | 9 ++++---- src/api/remote.rs | 42 ++++++++++++++++++---------------- src/provider.rs | 6 ++--- src/provider/events.rs | 52 +++++++++++++++++++++++++----------------- src/util.rs | 13 +++++++---- 7 files changed, 83 insertions(+), 57 deletions(-) diff --git a/src/api.rs b/src/api.rs index a2a34a2db..117c59e25 100644 --- a/src/api.rs +++ b/src/api.rs @@ -30,7 +30,7 @@ pub mod downloader; pub mod proto; pub mod remote; pub mod tags; -use crate::api::proto::WaitIdleRequest; +use crate::{api::proto::WaitIdleRequest, provider::events::ProgressError}; pub use crate::{store::util::Tag, util::temp_tag::TempTag}; pub(crate) type ApiClient = irpc::Client; @@ -97,6 +97,8 @@ pub enum ExportBaoError { ExportBaoIo { source: io::Error }, #[snafu(display("encode error: {source}"))] ExportBaoInner { source: bao_tree::io::EncodeError }, + #[snafu(display("client error: {source}"))] + ClientError { source: ProgressError }, } impl From for Error { @@ -107,6 +109,7 @@ impl From for Error { ExportBaoError::Request { source, .. } => Self::Io(source.into()), ExportBaoError::ExportBaoIo { source, .. } => Self::Io(source), ExportBaoError::ExportBaoInner { source, .. } => Self::Io(source.into()), + ExportBaoError::ClientError { source, .. 
} => Self::Io(source.into()), } } } @@ -152,6 +155,12 @@ impl From for ExportBaoError { } } +impl From for ExportBaoError { + fn from(value: ProgressError) -> Self { + ClientSnafu.into_error(value) + } +} + pub type ExportBaoResult = std::result::Result; #[derive(Debug, derive_more::Display, derive_more::From, Serialize, Deserialize)] diff --git a/src/api/blobs.rs b/src/api/blobs.rs index 8b618de1f..1822be5b2 100644 --- a/src/api/blobs.rs +++ b/src/api/blobs.rs @@ -57,6 +57,7 @@ use super::{ }; use crate::{ api::proto::{BatchRequest, ImportByteStreamUpdate}, + provider::events::ClientResult, store::IROH_BLOCK_SIZE, util::temp_tag::TempTag, BlobFormat, Hash, HashAndFormat, @@ -1112,7 +1113,9 @@ impl ExportBaoProgress { .write_chunk(leaf.data) .await .map_err(io::Error::from)?; - progress.notify_payload_write(index, leaf.offset, len).await; + progress + .notify_payload_write(index, leaf.offset, len) + .await?; } EncodedItem::Done => break, EncodedItem::Error(cause) => return Err(cause.into()), @@ -1158,7 +1161,7 @@ impl ExportBaoProgress { pub(crate) trait WriteProgress { /// Notify the progress writer that a payload write has happened. - async fn notify_payload_write(&mut self, index: u64, offset: u64, len: usize); + async fn notify_payload_write(&mut self, index: u64, offset: u64, len: usize) -> ClientResult; /// Log a write of some other data. fn log_other_write(&mut self, len: usize); diff --git a/src/api/downloader.rs b/src/api/downloader.rs index a2abbd7ea..1db1e6f07 100644 --- a/src/api/downloader.rs +++ b/src/api/downloader.rs @@ -3,7 +3,6 @@ use std::{ collections::{HashMap, HashSet}, fmt::Debug, future::{Future, IntoFuture}, - io, sync::Arc, }; @@ -113,7 +112,7 @@ async fn handle_download_impl( SplitStrategy::Split => handle_download_split_impl(store, pool, request, tx).await?, SplitStrategy::None => match request.request { FiniteRequest::Get(get) => { - let sink = IrpcSenderRefSink(tx).with_map_err(io::Error::other); + let sink = IrpcSenderRefSink(tx); execute_get(&pool, Arc::new(get), &request.providers, &store, sink).await?; } FiniteRequest::GetMany(_) => { @@ -144,7 +143,7 @@ async fn handle_download_split_impl( let (tx, rx) = tokio::sync::mpsc::channel::<(usize, DownloadProgessItem)>(16); progress_tx.send(rx).await.ok(); let sink = TokioMpscSenderSink(tx) - .with_map_err(io::Error::other) + .with_map_err(|_| irpc::channel::SendError::ReceiverClosed) .with_map(move |x| (id, x)); let res = execute_get(&pool, Arc::new(request), &providers, &store, sink).await; (hash, res) @@ -375,7 +374,7 @@ async fn split_request<'a>( providers: &Arc, pool: &ConnectionPool, store: &Store, - progress: impl Sink, + progress: impl Sink, ) -> anyhow::Result + Send + 'a>> { Ok(match request { FiniteRequest::Get(req) => { @@ -431,7 +430,7 @@ async fn execute_get( request: Arc, providers: &Arc, store: &Store, - mut progress: impl Sink, + mut progress: impl Sink, ) -> anyhow::Result<()> { let remote = store.remote(); let mut providers = providers.find_providers(request.content()); diff --git a/src/api/remote.rs b/src/api/remote.rs index 623200900..3d8a3a817 100644 --- a/src/api/remote.rs +++ b/src/api/remote.rs @@ -18,6 +18,7 @@ use crate::{ GetManyRequest, ObserveItem, ObserveRequest, PushRequest, Request, RequestType, MAX_MESSAGE_SIZE, }, + provider::events::{ClientResult, ProgressError}, util::sink::{Sink, TokioMpscSenderSink}, }; @@ -478,9 +479,7 @@ impl Remote { let content = content.into(); let (tx, rx) = tokio::sync::mpsc::channel(64); let tx2 = tx.clone(); - let sink = TokioMpscSenderSink(tx) - 
.with_map(GetProgressItem::Progress) - .with_map_err(io::Error::other); + let sink = TokioMpscSenderSink(tx).with_map(GetProgressItem::Progress); let this = self.clone(); let fut = async move { let res = this.fetch_sink(conn, content, sink).await.into(); @@ -503,7 +502,7 @@ impl Remote { &self, mut conn: impl GetConnection, content: impl Into, - progress: impl Sink, + progress: impl Sink, ) -> GetResult { let content = content.into(); let local = self @@ -556,9 +555,7 @@ impl Remote { pub fn execute_push(&self, conn: Connection, request: PushRequest) -> PushProgress { let (tx, rx) = tokio::sync::mpsc::channel(64); let tx2 = tx.clone(); - let sink = TokioMpscSenderSink(tx) - .with_map(PushProgressItem::Progress) - .with_map_err(io::Error::other); + let sink = TokioMpscSenderSink(tx).with_map(PushProgressItem::Progress); let this = self.clone(); let fut = async move { let res = this.execute_push_sink(conn, request, sink).await.into(); @@ -577,7 +574,7 @@ impl Remote { &self, conn: Connection, request: PushRequest, - progress: impl Sink, + progress: impl Sink, ) -> anyhow::Result { let hash = request.hash; debug!(%hash, "pushing"); @@ -632,9 +629,7 @@ impl Remote { pub fn execute_get_with_opts(&self, conn: Connection, request: GetRequest) -> GetProgress { let (tx, rx) = tokio::sync::mpsc::channel(64); let tx2 = tx.clone(); - let sink = TokioMpscSenderSink(tx) - .with_map(GetProgressItem::Progress) - .with_map_err(io::Error::other); + let sink = TokioMpscSenderSink(tx).with_map(GetProgressItem::Progress); let this = self.clone(); let fut = async move { let res = this.execute_get_sink(&conn, request, sink).await.into(); @@ -658,7 +653,7 @@ impl Remote { &self, conn: &Connection, request: GetRequest, - mut progress: impl Sink, + mut progress: impl Sink, ) -> GetResult { let store = self.store(); let root = request.hash; @@ -721,9 +716,7 @@ impl Remote { pub fn execute_get_many(&self, conn: Connection, request: GetManyRequest) -> GetProgress { let (tx, rx) = tokio::sync::mpsc::channel(64); let tx2 = tx.clone(); - let sink = TokioMpscSenderSink(tx) - .with_map(GetProgressItem::Progress) - .with_map_err(io::Error::other); + let sink = TokioMpscSenderSink(tx).with_map(GetProgressItem::Progress); let this = self.clone(); let fut = async move { let res = this.execute_get_many_sink(conn, request, sink).await.into(); @@ -747,7 +740,7 @@ impl Remote { &self, conn: Connection, request: GetManyRequest, - mut progress: impl Sink, + mut progress: impl Sink, ) -> GetResult { let store = self.store(); let hash_seq = request.hashes.iter().copied().collect::(); @@ -884,7 +877,7 @@ async fn get_blob_ranges_impl( header: AtBlobHeader, hash: Hash, store: &Store, - mut progress: impl Sink, + mut progress: impl Sink, ) -> GetResult { let (mut content, size) = header.next().await?; let Some(size) = NonZeroU64::new(size) else { @@ -1048,11 +1041,20 @@ struct StreamContext { impl WriteProgress for StreamContext where - S: Sink, + S: Sink, { - async fn notify_payload_write(&mut self, _index: u64, _offset: u64, len: usize) { + async fn notify_payload_write( + &mut self, + _index: u64, + _offset: u64, + len: usize, + ) -> ClientResult { self.payload_bytes_sent += len as u64; - self.sender.send(self.payload_bytes_sent).await.ok(); + self.sender + .send(self.payload_bytes_sent) + .await + .map_err(|e| ProgressError::Internal { source: e.into() })?; + Ok(()) } fn log_other_write(&mut self, _len: usize) {} diff --git a/src/provider.rs b/src/provider.rs index 49b57e13a..0134169c6 100644 --- a/src/provider.rs +++ 
b/src/provider.rs @@ -25,7 +25,7 @@ use crate::{ }, hashseq::HashSeq, protocol::{GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, Request}, - provider::events::{ClientConnected, ConnectionClosed, RequestTracker}, + provider::events::{ClientConnected, ClientResult, ConnectionClosed, RequestTracker}, Hash, }; pub mod events; @@ -264,11 +264,11 @@ impl WriterContext { } impl WriteProgress for WriterContext { - async fn notify_payload_write(&mut self, _index: u64, offset: u64, len: usize) { + async fn notify_payload_write(&mut self, _index: u64, offset: u64, len: usize) -> ClientResult { let len = len as u64; let end_offset = offset + len; self.payload_bytes_written += len; - self.tracker.transfer_progress(len, end_offset).await.ok(); + self.tracker.transfer_progress(len, end_offset).await } fn log_other_write(&mut self, len: usize) { diff --git a/src/provider/events.rs b/src/provider/events.rs index 5a922300a..fff800dc9 100644 --- a/src/provider/events.rs +++ b/src/provider/events.rs @@ -1,4 +1,4 @@ -use std::{fmt::Debug, ops::Deref}; +use std::{fmt::Debug, io, ops::Deref}; use irpc::{ channel::{mpsc, none::NoSender, oneshot}, @@ -76,60 +76,70 @@ pub enum AbortReason { } #[derive(Debug, Snafu)] -pub enum ClientError { - RateLimited, +pub enum ProgressError { + Limit, Permission, #[snafu(transparent)] - Irpc { + Internal { source: irpc::Error, }, } -impl ClientError { +impl From for io::Error { + fn from(value: ProgressError) -> Self { + match value { + ProgressError::Limit => io::ErrorKind::QuotaExceeded.into(), + ProgressError::Permission => io::ErrorKind::PermissionDenied.into(), + ProgressError::Internal { source } => source.into(), + } + } +} + +impl ProgressError { pub fn code(&self) -> quinn::VarInt { match self { - ClientError::RateLimited => ERR_LIMIT, - ClientError::Permission => ERR_PERMISSION, - ClientError::Irpc { .. } => ERR_INTERNAL, + ProgressError::Limit => ERR_LIMIT, + ProgressError::Permission => ERR_PERMISSION, + ProgressError::Internal { .. } => ERR_INTERNAL, } } pub fn reason(&self) -> &'static [u8] { match self { - ClientError::RateLimited => b"limit", - ClientError::Permission => b"permission", - ClientError::Irpc { .. } => b"internal", + ProgressError::Limit => b"limit", + ProgressError::Permission => b"permission", + ProgressError::Internal { .. 
} => b"internal", } } } -impl From for ClientError { +impl From for ProgressError { fn from(value: AbortReason) -> Self { match value { - AbortReason::RateLimited => ClientError::RateLimited, - AbortReason::Permission => ClientError::Permission, + AbortReason::RateLimited => ProgressError::Limit, + AbortReason::Permission => ProgressError::Permission, } } } -impl From for ClientError { +impl From for ProgressError { fn from(value: irpc::channel::RecvError) -> Self { - ClientError::Irpc { + ProgressError::Internal { source: value.into(), } } } -impl From for ClientError { +impl From for ProgressError { fn from(value: irpc::channel::SendError) -> Self { - ClientError::Irpc { + ProgressError::Internal { source: value.into(), } } } pub type EventResult = Result<(), AbortReason>; -pub type ClientResult = Result<(), ClientError>; +pub type ClientResult = Result<(), ProgressError>; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct EventMask { @@ -407,7 +417,7 @@ impl EventSender { f: impl FnOnce() -> Req, connection_id: u64, request_id: u64, - ) -> Result + ) -> Result where ProviderProto: From>, ProviderMessage: From, ProviderProto>>, @@ -466,7 +476,7 @@ impl EventSender { RequestUpdates::Active(tx) } RequestMode::Disabled => { - return Err(ClientError::Permission); + return Err(ProgressError::Permission); } _ => RequestUpdates::None, }, diff --git a/src/util.rs b/src/util.rs index 3fdaacbca..40abf0343 100644 --- a/src/util.rs +++ b/src/util.rs @@ -363,7 +363,7 @@ pub(crate) mod outboard_with_progress { } pub(crate) mod sink { - use std::{future::Future, io}; + use std::future::Future; use irpc::RpcMessage; @@ -433,10 +433,13 @@ pub(crate) mod sink { pub struct TokioMpscSenderSink(pub tokio::sync::mpsc::Sender); impl Sink for TokioMpscSenderSink { - type Error = tokio::sync::mpsc::error::SendError; + type Error = irpc::channel::SendError; async fn send(&mut self, value: T) -> std::result::Result<(), Self::Error> { - self.0.send(value).await + self.0 + .send(value) + .await + .map_err(|_| irpc::channel::SendError::ReceiverClosed) } } @@ -483,10 +486,10 @@ pub(crate) mod sink { pub struct Drain; impl Sink for Drain { - type Error = io::Error; + type Error = irpc::channel::SendError; async fn send(&mut self, _offset: T) -> std::result::Result<(), Self::Error> { - io::Result::Ok(()) + Ok(()) } } } From 546f57e90af8518c1c70e06078199987a5fc76d1 Mon Sep 17 00:00:00 2001 From: Frando Date: Wed, 3 Sep 2025 15:05:28 +0200 Subject: [PATCH 20/35] fixup --- examples/limit.rs | 18 +++++++++++------- examples/random_store.rs | 6 +++--- src/tests.rs | 13 ++----------- 3 files changed, 16 insertions(+), 21 deletions(-) diff --git a/examples/limit.rs b/examples/limit.rs index 830574fcc..e72f9be59 100644 --- a/examples/limit.rs +++ b/examples/limit.rs @@ -271,7 +271,7 @@ async fn main() -> Result<()> { let store = MemStore::new(); let hashes = add_paths(&store, paths).await?; let events = limit_by_node_id(allowed_nodes.clone()); - let (router, addr) = setup(MemStore::new(), events).await?; + let (router, addr) = setup(store, events).await?; for (path, hash) in hashes { let ticket = BlobTicket::new(addr.clone(), hash, BlobFormat::Raw); @@ -299,12 +299,16 @@ async fn main() -> Result<()> { } } - let events = limit_by_hash(allowed_hashes); - let (router, addr) = setup(MemStore::new(), events).await?; + let events = limit_by_hash(allowed_hashes.clone()); + let (router, addr) = setup(store, events).await?; - for (i, (path, hash)) in hashes.iter().enumerate() { + for (path, hash) in hashes.iter() { let ticket = 
BlobTicket::new(addr.clone(), *hash, BlobFormat::Raw); - let permitted = if i == 0 { "" } else { "limited" }; + let permitted = if allowed_hashes.contains(hash) { + "allowed" + } else { + "forbidden" + }; println!("{}: {ticket} ({permitted})", path.display()); } tokio::signal::ctrl_c().await?; @@ -314,7 +318,7 @@ async fn main() -> Result<()> { let store = MemStore::new(); let hashes = add_paths(&store, paths).await?; let events = throttle(delay_ms); - let (router, addr) = setup(MemStore::new(), events).await?; + let (router, addr) = setup(store, events).await?; for (path, hash) in hashes { let ticket = BlobTicket::new(addr.clone(), hash, BlobFormat::Raw); println!("{}: {ticket}", path.display()); @@ -329,7 +333,7 @@ async fn main() -> Result<()> { let store = MemStore::new(); let hashes = add_paths(&store, paths).await?; let events = limit_max_connections(max_connections); - let (router, addr) = setup(MemStore::new(), events).await?; + let (router, addr) = setup(store, events).await?; for (path, hash) in hashes { let ticket = BlobTicket::new(addr.clone(), hash, BlobFormat::Raw); println!("{}: {ticket}", path.display()); diff --git a/examples/random_store.rs b/examples/random_store.rs index c4c30348b..d3f9a0fc4 100644 --- a/examples/random_store.rs +++ b/examples/random_store.rs @@ -14,7 +14,7 @@ use iroh_blobs::{ use irpc::RpcMessage; use n0_future::StreamExt; use rand::{rngs::StdRng, Rng, SeedableRng}; -use tokio::{signal::ctrl_c, sync::mpsc}; +use tokio::signal::ctrl_c; use tracing::info; #[derive(Parser, Debug)] @@ -102,7 +102,7 @@ pub fn get_or_generate_secret_key() -> Result { } pub fn dump_provider_events(allow_push: bool) -> (tokio::task::JoinHandle<()>, EventSender) { - let (tx, mut rx) = mpsc::channel(100); + let (tx, mut rx) = EventSender::channel(100, EventMask::ALL_READONLY); fn dump_updates(mut rx: irpc::channel::mpsc::Receiver) { tokio::spawn(async move { while let Ok(Some(update)) = rx.recv().await { @@ -176,7 +176,7 @@ pub fn dump_provider_events(allow_push: bool) -> (tokio::task::JoinHandle<()>, E } } }); - (dump_task, EventSender::new(tx, EventMask::ALL_READONLY)) + (dump_task, tx) } #[tokio::main] diff --git a/src/tests.rs b/src/tests.rs index dc38eb436..0ef0c027c 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -342,7 +342,7 @@ fn event_handler( allowed_nodes: impl IntoIterator, ) -> (EventSender, watch::Receiver, AbortOnDropHandle<()>) { let (count_tx, count_rx) = tokio::sync::watch::channel(0usize); - let (events_tx, mut events_rx) = mpsc::channel::(16); + let (events_tx, mut events_rx) = EventSender::channel(16, EventMask::ALL_READONLY); let allowed_nodes = allowed_nodes.into_iter().collect::>(); let task = AbortOnDropHandle::new(tokio::task::spawn(async move { while let Some(event) = events_rx.recv().await { @@ -370,16 +370,7 @@ fn event_handler( } } })); - ( - EventSender::new( - events_tx, - EventMask { - ..EventMask::ALL_READONLY - }, - ), - count_rx, - task, - ) + (events_tx, count_rx, task) } async fn two_nodes_push_blobs( From f399e2bc44089dc398b9ba7d7572fdbb556443fb Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Wed, 3 Sep 2025 16:19:38 +0200 Subject: [PATCH 21/35] Remove map_err that isn't needed anymore --- src/api/downloader.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/api/downloader.rs b/src/api/downloader.rs index 1db1e6f07..3555eca9c 100644 --- a/src/api/downloader.rs +++ b/src/api/downloader.rs @@ -142,9 +142,7 @@ async fn handle_download_split_impl( let hash = request.hash; let (tx, rx) = 
tokio::sync::mpsc::channel::<(usize, DownloadProgessItem)>(16); progress_tx.send(rx).await.ok(); - let sink = TokioMpscSenderSink(tx) - .with_map_err(|_| irpc::channel::SendError::ReceiverClosed) - .with_map(move |x| (id, x)); + let sink = TokioMpscSenderSink(tx).with_map(move |x| (id, x)); let res = execute_get(&pool, Arc::new(request), &providers, &store, sink).await; (hash, res) } From 3f0a661959456e80df279a996365034f025c23af Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Thu, 4 Sep 2025 09:49:37 +0200 Subject: [PATCH 22/35] Refactor the GetError to be just a list of things that can go wrong. Also make the whole get fsm generic so it can be used with an arbitrary stream, not just a quinn/iroh RecvStream. --- src/api/remote.rs | 12 +- src/get.rs | 353 ++++++++++++++++++++------------------------ src/get/error.rs | 359 +++++++++++---------------------------------- src/get/request.rs | 2 +- 4 files changed, 254 insertions(+), 472 deletions(-) diff --git a/src/api/remote.rs b/src/api/remote.rs index 3d8a3a817..0d729081e 100644 --- a/src/api/remote.rs +++ b/src/api/remote.rs @@ -8,12 +8,16 @@ use n0_future::{io, Stream, StreamExt}; use n0_snafu::SpanTrace; use nested_enum_utils::common_fields; use ref_cast::RefCast; -use snafu::{Backtrace, IntoError, Snafu}; +use snafu::{Backtrace, IntoError, ResultExt, Snafu}; use super::blobs::{Bitfield, ExportBaoOptions}; use crate::{ api::{blobs::WriteProgress, ApiClient}, - get::{fsm::DecodeError, BadRequestSnafu, GetError, GetResult, LocalFailureSnafu, Stats}, + get::{ + fsm::DecodeError, + get_error::{BadRequestSnafu, LocalFailureSnafu}, + GetError, GetResult, Stats, + }, protocol::{ GetManyRequest, ObserveItem, ObserveRequest, PushRequest, Request, RequestType, MAX_MESSAGE_SIZE, @@ -508,7 +512,7 @@ impl Remote { let local = self .local(content) .await - .map_err(|e| LocalFailureSnafu.into_error(e.into()))?; + .map_err(|e: anyhow::Error| LocalFailureSnafu.into_error(e.into()))?; if local.is_complete() { return Ok(Default::default()); } @@ -685,7 +689,7 @@ impl Remote { .await .map_err(|e| LocalFailureSnafu.into_error(e.into()))?, ) - .map_err(|source| BadRequestSnafu.into_error(source.into()))?; + .context(BadRequestSnafu)?; // let mut hash_seq = LazyHashSeq::new(store.blobs().clone(), root); loop { let at_start_child = match next_child { diff --git a/src/get.rs b/src/get.rs index 049ef4855..6032857e4 100644 --- a/src/get.rs +++ b/src/get.rs @@ -17,7 +17,6 @@ //! //! [iroh]: https://docs.rs/iroh use std::{ - error::Error, fmt::{self, Debug}, time::{Duration, Instant}, }; @@ -25,8 +24,8 @@ use std::{ use anyhow::Result; use bao_tree::{io::fsm::BaoContentItem, ChunkNum}; use fsm::RequestCounters; -use iroh::endpoint::{self, RecvStream, SendStream}; -use iroh_io::TokioStreamReader; +use iroh::endpoint::{RecvStream, SendStream}; +use iroh_io::{TokioStreamReader, TokioStreamWriter}; use n0_snafu::SpanTrace; use nested_enum_utils::common_fields; use serde::{Deserialize, Serialize}; @@ -37,10 +36,11 @@ use crate::{protocol::ChunkRangesSeq, store::IROH_BLOCK_SIZE, Hash}; mod error; pub mod request; -pub(crate) use error::{BadRequestSnafu, LocalFailureSnafu}; +pub(crate) use error::get_error; pub use error::{GetError, GetResult}; -type WrappedRecvStream = TokioStreamReader; +type DefaultReader = TokioStreamReader; +type DefaultWriter = TokioStreamWriter; /// Stats about the transfer. 
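To make the hunk above easier to follow: `DefaultReader`/`DefaultWriter` keep the iroh QUIC streams as the default transport, while the states themselves only require the iroh-io stream traits. A minimal sketch of what the reader side has to support, assuming nothing beyond `AsyncStreamReader` (this mirrors the 8-byte little-endian size header read that `AtBlobHeader::next` performs later in this patch; `read_size_header` is an illustrative helper, not part of the patch):

```rust
use iroh_io::AsyncStreamReader;

// Sketch: the per-blob size header is just 8 LE bytes on the stream, so any
// AsyncStreamReader can parse it; nothing here is quinn-specific.
async fn read_size_header<R: AsyncStreamReader>(reader: &mut R) -> std::io::Result<u64> {
    let buf: [u8; 8] = reader.read::<8>().await?;
    Ok(u64::from_le_bytes(buf))
}
```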
#[derive( @@ -96,11 +96,11 @@ pub mod fsm { }; use derive_more::From; use iroh::endpoint::Connection; - use iroh_io::{AsyncSliceWriter, AsyncStreamReader, TokioStreamReader}; + use iroh_io::{AsyncSliceWriter, AsyncStreamReader, AsyncStreamWriter, TokioStreamReader}; use super::*; use crate::{ - get::error::BadRequestSnafu, + get::get_error::BadRequestSnafu, protocol::{ GetManyRequest, GetRequest, NonEmptyRequestRangeSpecIter, Request, MAX_MESSAGE_SIZE, }, @@ -130,16 +130,22 @@ pub mod fsm { counters: RequestCounters, ) -> std::result::Result, GetError> { let start = Instant::now(); - let (mut writer, reader) = connection.open_bi().await?; + let (writer, reader) = connection + .open_bi() + .await + .map_err(|e| OpenSnafu.into_error(e.into()))?; + let reader = TokioStreamReader::new(reader); + let mut writer = TokioStreamWriter(writer); let request = Request::GetMany(request); let request_bytes = postcard::to_stdvec(&request) .map_err(|source| BadRequestSnafu.into_error(source.into()))?; - writer.write_all(&request_bytes).await?; - writer.finish()?; + writer + .write_bytes(request_bytes.into()) + .await + .context(connected_next_error::WriteSnafu)?; let Request::GetMany(request) = request else { unreachable!(); }; - let reader = TokioStreamReader::new(reader); let mut ranges_iter = RangesIter::new(request.ranges.clone()); let first_item = ranges_iter.next(); let misc = Box::new(Misc { @@ -214,10 +220,15 @@ pub mod fsm { } /// Initiate a new bidi stream to use for the get response - pub async fn next(self) -> Result { + pub async fn next(self) -> Result { let start = Instant::now(); - let (writer, reader) = self.connection.open_bi().await?; + let (writer, reader) = self + .connection + .open_bi() + .await + .map_err(|e| OpenSnafu.into_error(e.into()))?; let reader = TokioStreamReader::new(reader); + let writer = TokioStreamWriter(writer); Ok(AtConnected { start, reader, @@ -228,25 +239,41 @@ pub mod fsm { } } + /// Error that you can get from [`AtConnected::next`] + #[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: SpanTrace, + })] + #[allow(missing_docs)] + #[derive(Debug, Snafu)] + #[non_exhaustive] + pub enum InitialNextError { + Open { source: io::Error }, + } + /// State of the get response machine after the handshake has been sent #[derive(Debug)] - pub struct AtConnected { + pub struct AtConnected< + R: AsyncStreamReader = DefaultReader, + W: AsyncStreamWriter = DefaultWriter, + > { start: Instant, - reader: WrappedRecvStream, - writer: SendStream, + reader: R, + writer: W, request: GetRequest, counters: RequestCounters, } /// Possible next states after the handshake has been sent #[derive(Debug, From)] - pub enum ConnectedNext { + pub enum ConnectedNext { /// First response is either a collection or a single blob - StartRoot(AtStartRoot), + StartRoot(AtStartRoot), /// First response is a child - StartChild(AtStartChild), + StartChild(AtStartChild), /// Request is empty - Closing(AtClosing), + Closing(AtClosing), } /// Error that you can get from [`AtConnected::next`] @@ -257,6 +284,7 @@ pub mod fsm { })] #[allow(missing_docs)] #[derive(Debug, Snafu)] + #[snafu(module)] #[non_exhaustive] pub enum ConnectedNextError { /// Error when serializing the request @@ -267,23 +295,17 @@ pub mod fsm { RequestTooBig {}, /// Error when writing the request to the [`SendStream`]. #[snafu(display("write: {source}"))] - Write { source: quinn::WriteError }, - /// Quic connection is closed. 
- #[snafu(display("closed"))] - Closed { source: quinn::ClosedStream }, - /// A generic io error - #[snafu(transparent)] - Io { source: io::Error }, + Write { source: io::Error }, } - impl AtConnected { + impl AtConnected { /// Send the request and move to the next state /// /// The next state will be either `StartRoot` or `StartChild` depending on whether /// the request requests part of the collection or not. /// /// If the request is empty, this can also move directly to `Finished`. - pub async fn next(self) -> Result { + pub async fn next(self) -> Result, ConnectedNextError> { let Self { start, reader, @@ -295,23 +317,28 @@ pub mod fsm { counters.other_bytes_written += { debug!("sending request"); let wrapped = Request::Get(request); - let request_bytes = postcard::to_stdvec(&wrapped).context(PostcardSerSnafu)?; + let request_bytes = postcard::to_stdvec(&wrapped) + .context(connected_next_error::PostcardSerSnafu)?; let Request::Get(x) = wrapped else { unreachable!(); }; request = x; if request_bytes.len() > MAX_MESSAGE_SIZE { - return Err(RequestTooBigSnafu.build()); + return Err(connected_next_error::RequestTooBigSnafu.build()); } // write the request itself - writer.write_all(&request_bytes).await.context(WriteSnafu)?; - request_bytes.len() as u64 + let len = request_bytes.len() as u64; + writer + .write_bytes(request_bytes.into()) + .await + .context(connected_next_error::WriteSnafu)?; + len }; // 2. Finish writing before expecting a response - writer.finish().context(ClosedSnafu)?; + drop(writer); let hash = request.hash; let ranges_iter = RangesIter::new(request.ranges); @@ -348,23 +375,23 @@ pub mod fsm { /// State of the get response when we start reading a collection #[derive(Debug)] - pub struct AtStartRoot { + pub struct AtStartRoot> { ranges: ChunkRanges, - reader: TokioStreamReader, + reader: R, misc: Box, hash: Hash, } /// State of the get response when we start reading a child #[derive(Debug)] - pub struct AtStartChild { + pub struct AtStartChild> { ranges: ChunkRanges, - reader: TokioStreamReader, + reader: R, misc: Box, offset: u64, } - impl AtStartChild { + impl AtStartChild { /// The offset of the child we are currently reading /// /// This must be used to determine the hash needed to call next. @@ -382,7 +409,7 @@ pub mod fsm { /// Go into the next state, reading the header /// /// This requires passing in the hash of the child for validation - pub fn next(self, hash: Hash) -> AtBlobHeader { + pub fn next(self, hash: Hash) -> AtBlobHeader { AtBlobHeader { reader: self.reader, ranges: self.ranges, @@ -396,12 +423,12 @@ pub mod fsm { /// This is used if you know that there are no more children from having /// read the collection, or when you want to stop reading the response /// early. 
- pub fn finish(self) -> AtClosing { + pub fn finish(self) -> AtClosing { AtClosing::new(self.misc, self.reader, false) } } - impl AtStartRoot { + impl AtStartRoot { /// The ranges we have requested for the child pub fn ranges(&self) -> &ChunkRanges { &self.ranges @@ -415,7 +442,7 @@ pub mod fsm { /// Go into the next state, reading the header /// /// For the collection we already know the hash, since it was part of the request - pub fn next(self) -> AtBlobHeader { + pub fn next(self) -> AtBlobHeader { AtBlobHeader { reader: self.reader, ranges: self.ranges, @@ -425,16 +452,16 @@ pub mod fsm { } /// Finish the get response without reading further - pub fn finish(self) -> AtClosing { + pub fn finish(self) -> AtClosing { AtClosing::new(self.misc, self.reader, false) } } /// State before reading a size header #[derive(Debug)] - pub struct AtBlobHeader { + pub struct AtBlobHeader> { ranges: ChunkRanges, - reader: TokioStreamReader, + reader: R, misc: Box, hash: Hash, } @@ -447,18 +474,16 @@ pub mod fsm { })] #[non_exhaustive] #[derive(Debug, Snafu)] + #[snafu(module)] pub enum AtBlobHeaderNextError { /// Eof when reading the size header /// /// This indicates that the provider does not have the requested data. #[snafu(display("not found"))] NotFound {}, - /// Quinn read error when reading the size header - #[snafu(display("read: {source}"))] - EndpointRead { source: endpoint::ReadError }, /// Generic io error #[snafu(display("io: {source}"))] - Io { source: io::Error }, + Read { source: io::Error }, } impl From for io::Error { @@ -467,25 +492,19 @@ pub mod fsm { AtBlobHeaderNextError::NotFound { .. } => { io::Error::new(io::ErrorKind::UnexpectedEof, cause) } - AtBlobHeaderNextError::EndpointRead { source, .. } => source.into(), - AtBlobHeaderNextError::Io { source, .. } => source, + AtBlobHeaderNextError::Read { source, .. } => source, } } } - impl AtBlobHeader { + impl AtBlobHeader { /// Read the size header, returning it and going into the `Content` state. - pub async fn next(mut self) -> Result<(AtBlobContent, u64), AtBlobHeaderNextError> { + pub async fn next(mut self) -> Result<(AtBlobContent, u64), AtBlobHeaderNextError> { let size = self.reader.read::<8>().await.map_err(|cause| { if cause.kind() == io::ErrorKind::UnexpectedEof { - NotFoundSnafu.build() - } else if let Some(e) = cause - .get_ref() - .and_then(|x| x.downcast_ref::()) - { - EndpointReadSnafu.into_error(e.clone()) + at_blob_header_next_error::NotFoundSnafu.build() } else { - IoSnafu.into_error(cause) + at_blob_header_next_error::ReadSnafu.into_error(cause) } })?; self.misc.other_bytes_read += 8; @@ -506,7 +525,7 @@ pub mod fsm { } /// Drain the response and throw away the result - pub async fn drain(self) -> result::Result { + pub async fn drain(self) -> result::Result, DecodeError> { let (content, _size) = self.next().await?; content.drain().await } @@ -517,7 +536,7 @@ pub mod fsm { /// concatenate the ranges that were requested. 
pub async fn concatenate_into_vec( self, - ) -> result::Result<(AtEndBlob, Vec), DecodeError> { + ) -> result::Result<(AtEndBlob, Vec), DecodeError> { let (content, _size) = self.next().await?; content.concatenate_into_vec().await } @@ -526,7 +545,7 @@ pub mod fsm { pub async fn write_all( self, data: D, - ) -> result::Result { + ) -> result::Result, DecodeError> { let (content, _size) = self.next().await?; let res = content.write_all(data).await?; Ok(res) @@ -540,7 +559,7 @@ pub mod fsm { self, outboard: Option, data: D, - ) -> result::Result + ) -> result::Result, DecodeError> where D: AsyncSliceWriter, O: OutboardMut, @@ -568,8 +587,8 @@ pub mod fsm { /// State while we are reading content #[derive(Debug)] - pub struct AtBlobContent { - stream: ResponseDecoder, + pub struct AtBlobContent> { + stream: ResponseDecoder, misc: Box, } @@ -603,6 +622,7 @@ pub mod fsm { })] #[non_exhaustive] #[derive(Debug, Snafu)] + #[snafu(module)] pub enum DecodeError { /// A chunk was not found or invalid, so the provider stopped sending data #[snafu(display("not found"))] @@ -621,24 +641,25 @@ pub mod fsm { LeafHashMismatch { num: ChunkNum }, /// Error when reading from the stream #[snafu(display("read: {source}"))] - Read { source: endpoint::ReadError }, + Read { source: io::Error }, /// A generic io error #[snafu(display("io: {source}"))] - DecodeIo { source: io::Error }, + Write { source: io::Error }, } impl DecodeError { pub(crate) fn leaf_hash_mismatch(num: ChunkNum) -> Self { - LeafHashMismatchSnafu { num }.build() + decode_error::LeafHashMismatchSnafu { num }.build() } } impl From for DecodeError { fn from(cause: AtBlobHeaderNextError) -> Self { match cause { - AtBlobHeaderNextError::NotFound { .. } => ChunkNotFoundSnafu.build(), - AtBlobHeaderNextError::EndpointRead { source, .. } => ReadSnafu.into_error(source), - AtBlobHeaderNextError::Io { source, .. } => DecodeIoSnafu.into_error(source), + AtBlobHeaderNextError::NotFound { .. } => decode_error::ChunkNotFoundSnafu.build(), + AtBlobHeaderNextError::Read { source, .. } => { + decode_error::ReadSnafu.into_error(source) + } } } } @@ -653,58 +674,49 @@ pub mod fsm { io::Error::new(io::ErrorKind::UnexpectedEof, cause) } DecodeError::Read { source, .. } => source.into(), - DecodeError::DecodeIo { source, .. } => source, + DecodeError::Write { source, .. 
} => source, _ => io::Error::other(cause), } } } - impl From for DecodeError { - fn from(value: io::Error) -> Self { - DecodeIoSnafu.into_error(value) - } - } - impl From for DecodeError { fn from(value: bao_tree::io::DecodeError) -> Self { match value { bao_tree::io::DecodeError::ParentNotFound(x) => { - ParentNotFoundSnafu { node: x }.build() + decode_error::ParentNotFoundSnafu { node: x }.build() + } + bao_tree::io::DecodeError::LeafNotFound(x) => { + decode_error::LeafNotFoundSnafu { num: x }.build() } - bao_tree::io::DecodeError::LeafNotFound(x) => LeafNotFoundSnafu { num: x }.build(), bao_tree::io::DecodeError::ParentHashMismatch(node) => { - ParentHashMismatchSnafu { node }.build() + decode_error::ParentHashMismatchSnafu { node }.build() } bao_tree::io::DecodeError::LeafHashMismatch(chunk) => { - LeafHashMismatchSnafu { num: chunk }.build() - } - bao_tree::io::DecodeError::Io(cause) => { - if let Some(inner) = cause.get_ref() { - if let Some(e) = inner.downcast_ref::() { - ReadSnafu.into_error(e.clone()) - } else { - DecodeIoSnafu.into_error(cause) - } - } else { - DecodeIoSnafu.into_error(cause) - } + decode_error::LeafHashMismatchSnafu { num: chunk }.build() } + bao_tree::io::DecodeError::Io(cause) => decode_error::ReadSnafu.into_error(cause), } } } /// The next state after reading a content item #[derive(Debug, From)] - pub enum BlobContentNext { + pub enum BlobContentNext { /// We expect more content - More((AtBlobContent, result::Result)), + More( + ( + AtBlobContent, + result::Result, + ), + ), /// We are done with this blob - Done(AtEndBlob), + Done(AtEndBlob), } - impl AtBlobContent { + impl AtBlobContent { /// Read the next item, either content, an error, or the end of the blob - pub async fn next(self) -> BlobContentNext { + pub async fn next(self) -> BlobContentNext { match self.stream.next().await { ResponseDecoderNext::More((stream, res)) => { let mut next = Self { stream, ..self }; @@ -751,7 +763,7 @@ pub mod fsm { } /// Drain the response and throw away the result - pub async fn drain(self) -> result::Result { + pub async fn drain(self) -> result::Result, DecodeError> { let mut content = self; loop { match content.next().await { @@ -769,7 +781,7 @@ pub mod fsm { /// Concatenate the entire response into a vec pub async fn concatenate_into_vec( self, - ) -> result::Result<(AtEndBlob, Vec), DecodeError> { + ) -> result::Result<(AtEndBlob, Vec), DecodeError> { let mut res = Vec::with_capacity(1024); let mut curr = self; let done = loop { @@ -797,7 +809,7 @@ pub mod fsm { self, mut outboard: Option, mut data: D, - ) -> result::Result + ) -> result::Result, DecodeError> where D: AsyncSliceWriter, O: OutboardMut, @@ -810,11 +822,16 @@ pub mod fsm { match item? { BaoContentItem::Parent(parent) => { if let Some(outboard) = outboard.as_mut() { - outboard.save(parent.node, &parent.pair).await?; + outboard + .save(parent.node, &parent.pair) + .await + .map_err(|e| decode_error::WriteSnafu.into_error(e))?; } } BaoContentItem::Leaf(leaf) => { - data.write_bytes_at(leaf.offset, leaf.data).await?; + data.write_bytes_at(leaf.offset, leaf.data) + .await + .map_err(|e| decode_error::WriteSnafu.into_error(e))?; } } } @@ -826,7 +843,7 @@ pub mod fsm { } /// Write the entire blob to a slice writer. - pub async fn write_all(self, mut data: D) -> result::Result + pub async fn write_all(self, mut data: D) -> result::Result, DecodeError> where D: AsyncSliceWriter, { @@ -838,7 +855,9 @@ pub mod fsm { match item? 
{ BaoContentItem::Parent(_) => {} BaoContentItem::Leaf(leaf) => { - data.write_bytes_at(leaf.offset, leaf.data).await?; + data.write_bytes_at(leaf.offset, leaf.data) + .await + .map_err(|e| decode_error::WriteSnafu.into_error(e))?; } } } @@ -850,30 +869,30 @@ pub mod fsm { } /// Immediately finish the get response without reading further - pub fn finish(self) -> AtClosing { + pub fn finish(self) -> AtClosing { AtClosing::new(self.misc, self.stream.finish(), false) } } /// State after we have read all the content for a blob #[derive(Debug)] - pub struct AtEndBlob { - stream: WrappedRecvStream, + pub struct AtEndBlob> { + stream: R, misc: Box, } /// The next state after the end of a blob #[derive(Debug, From)] - pub enum EndBlobNext { + pub enum EndBlobNext> { /// Response is expected to have more children - MoreChildren(AtStartChild), + MoreChildren(AtStartChild), /// No more children expected - Closing(AtClosing), + Closing(AtClosing), } - impl AtEndBlob { + impl AtEndBlob { /// Read the next child, or finish - pub fn next(mut self) -> EndBlobNext { + pub fn next(mut self) -> EndBlobNext { if let Some((offset, ranges)) = self.misc.ranges_iter.next() { AtStartChild { reader: self.stream, @@ -890,14 +909,14 @@ pub mod fsm { /// State when finishing the get response #[derive(Debug)] - pub struct AtClosing { + pub struct AtClosing> { misc: Box, - reader: WrappedRecvStream, + reader: R, check_extra_data: bool, } - impl AtClosing { - fn new(misc: Box, reader: WrappedRecvStream, check_extra_data: bool) -> Self { + impl AtClosing { + fn new(misc: Box, reader: R, check_extra_data: bool) -> Self { Self { misc, reader, @@ -906,17 +925,14 @@ pub mod fsm { } /// Finish the get response, returning statistics - pub async fn next(self) -> result::Result { + pub async fn next(self) -> result::Result { // Shut down the stream - let reader = self.reader; - let mut reader = reader.into_inner(); + let mut reader = self.reader; if self.check_extra_data { - if let Some(chunk) = reader.read_chunk(8, false).await? 
{ - reader.stop(0u8.into()).ok(); - error!("Received unexpected data from the provider: {chunk:?}"); + let rest = reader.read_bytes(1).await?; + if !rest.is_empty() { + error!("Unexpected extra data at the end of the stream"); } - } else { - reader.stop(0u8.into()).ok(); } Ok(Stats { counters: self.misc.counters, @@ -925,6 +941,21 @@ pub mod fsm { } } + /// Error that you can get from [`AtBlobHeader::next`] + #[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: SpanTrace, + })] + #[non_exhaustive] + #[derive(Debug, Snafu)] + #[snafu(module)] + pub enum AtClosingNextError { + /// Generic io error + #[snafu(transparent)] + Read { source: io::Error }, + } + #[derive(Debug, Serialize, Deserialize, Default, Clone, Copy, PartialEq, Eq)] pub struct RequestCounters { /// payload bytes written @@ -950,71 +981,3 @@ pub mod fsm { ranges_iter: RangesIter, } } - -/// Error when processing a response -#[common_fields({ - backtrace: Option, - #[snafu(implicit)] - span_trace: SpanTrace, -})] -#[allow(missing_docs)] -#[non_exhaustive] -#[derive(Debug, Snafu)] -pub enum GetResponseError { - /// Error when opening a stream - #[snafu(display("connection: {source}"))] - Connection { source: endpoint::ConnectionError }, - /// Error when writing the handshake or request to the stream - #[snafu(display("write: {source}"))] - Write { source: endpoint::WriteError }, - /// Error when reading from the stream - #[snafu(display("read: {source}"))] - Read { source: endpoint::ReadError }, - /// Error when decoding, e.g. hash mismatch - #[snafu(display("decode: {source}"))] - Decode { source: bao_tree::io::DecodeError }, - /// A generic error - #[snafu(display("generic: {source}"))] - Generic { source: anyhow::Error }, -} - -impl From for GetResponseError { - fn from(cause: postcard::Error) -> Self { - GenericSnafu.into_error(cause.into()) - } -} - -impl From for GetResponseError { - fn from(cause: bao_tree::io::DecodeError) -> Self { - match cause { - bao_tree::io::DecodeError::Io(cause) => { - // try to downcast to specific quinn errors - if let Some(source) = cause.source() { - if let Some(error) = source.downcast_ref::() { - return ConnectionSnafu.into_error(error.clone()); - } - if let Some(error) = source.downcast_ref::() { - return ReadSnafu.into_error(error.clone()); - } - if let Some(error) = source.downcast_ref::() { - return WriteSnafu.into_error(error.clone()); - } - } - GenericSnafu.into_error(cause.into()) - } - _ => DecodeSnafu.into_error(cause), - } - } -} - -impl From for GetResponseError { - fn from(cause: anyhow::Error) -> Self { - GenericSnafu.into_error(cause) - } -} - -impl From for std::io::Error { - fn from(cause: GetResponseError) -> Self { - Self::other(cause) - } -} diff --git a/src/get/error.rs b/src/get/error.rs index 1c3ea9465..c949593e7 100644 --- a/src/get/error.rs +++ b/src/get/error.rs @@ -1,102 +1,15 @@ //! 
Error returned from get operations use std::io; -use iroh::endpoint::{self, ClosedStream}; +use iroh::endpoint::{ConnectionError, ReadError, VarInt, WriteError}; use n0_snafu::SpanTrace; use nested_enum_utils::common_fields; -use quinn::{ConnectionError, ReadError, WriteError}; -use snafu::{Backtrace, IntoError, Snafu}; +use snafu::{Backtrace, Snafu}; -use crate::{ - api::ExportBaoError, - get::fsm::{AtBlobHeaderNextError, ConnectedNextError, DecodeError}, +use crate::get::fsm::{ + AtBlobHeaderNextError, AtClosingNextError, ConnectedNextError, DecodeError, InitialNextError, }; -#[derive(Debug, Snafu)] -pub enum NotFoundCases { - #[snafu(transparent)] - AtBlobHeaderNext { source: AtBlobHeaderNextError }, - #[snafu(transparent)] - Decode { source: DecodeError }, -} - -#[derive(Debug, Snafu)] -pub enum NoncompliantNodeCases { - #[snafu(transparent)] - Connection { source: ConnectionError }, - #[snafu(transparent)] - Decode { source: DecodeError }, -} - -#[derive(Debug, Snafu)] -pub enum RemoteResetCases { - #[snafu(transparent)] - Read { source: ReadError }, - #[snafu(transparent)] - Write { source: WriteError }, - #[snafu(transparent)] - Connection { source: ConnectionError }, -} - -#[derive(Debug, Snafu)] -pub enum BadRequestCases { - #[snafu(transparent)] - Anyhow { source: anyhow::Error }, - #[snafu(transparent)] - Postcard { source: postcard::Error }, - #[snafu(transparent)] - ConnectedNext { source: ConnectedNextError }, -} - -#[derive(Debug, Snafu)] -pub enum LocalFailureCases { - #[snafu(transparent)] - Io { - source: io::Error, - }, - #[snafu(transparent)] - Anyhow { - source: anyhow::Error, - }, - #[snafu(transparent)] - IrpcSend { - source: irpc::channel::SendError, - }, - #[snafu(transparent)] - Irpc { - source: irpc::Error, - }, - #[snafu(transparent)] - ExportBao { - source: ExportBaoError, - }, - TokioSend {}, -} - -impl From> for LocalFailureCases { - fn from(_: tokio::sync::mpsc::error::SendError) -> Self { - LocalFailureCases::TokioSend {} - } -} - -#[derive(Debug, Snafu)] -pub enum IoCases { - #[snafu(transparent)] - Io { source: io::Error }, - #[snafu(transparent)] - ConnectionError { source: endpoint::ConnectionError }, - #[snafu(transparent)] - ReadError { source: endpoint::ReadError }, - #[snafu(transparent)] - WriteError { source: endpoint::WriteError }, - #[snafu(transparent)] - ClosedStream { source: endpoint::ClosedStream }, - #[snafu(transparent)] - ConnectedNextError { source: ConnectedNextError }, - #[snafu(transparent)] - AtBlobHeaderNextError { source: AtBlobHeaderNextError }, -} - /// Failures for a get operation #[common_fields({ backtrace: Option, @@ -105,210 +18,112 @@ pub enum IoCases { })] #[derive(Debug, Snafu)] #[snafu(visibility(pub(crate)))] +#[snafu(module)] pub enum GetError { - /// Hash not found, or a requested chunk for the hash not found. - #[snafu(display("Data for hash not found"))] - NotFound { - #[snafu(source(from(NotFoundCases, Box::new)))] - source: Box, + #[snafu(transparent)] + InitialNext { + source: InitialNextError, }, - /// Remote has reset the connection. - #[snafu(display("Remote has reset the connection"))] - RemoteReset { - #[snafu(source(from(RemoteResetCases, Box::new)))] - source: Box, + #[snafu(transparent)] + ConnectedNext { + source: ConnectedNextError, }, - /// Remote behaved in a non-compliant way. 
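The wrapper enums deleted above each funneled several concrete errors into one broad bucket; the replacement keeps a flat variant per FSM stage and marks it `#[snafu(transparent)]`, so a stage error lifts into `GetError` with a plain `?`. A small self-contained sketch of that pattern with illustrative names, leaving out the `common_fields` boilerplate:

```rust
use snafu::Snafu;

#[derive(Debug, Snafu)]
enum StageError {
    #[snafu(display("read: {source}"))]
    Read { source: std::io::Error },
}

#[derive(Debug, Snafu)]
enum TopError {
    // transparent delegates Display and source() to StageError and derives
    // From<StageError>, so `?` converts without an explicit context selector
    #[snafu(transparent)]
    Stage { source: StageError },
}

fn stage() -> Result<(), StageError> {
    Ok(())
}

fn top() -> Result<(), TopError> {
    stage()?; // StageError -> TopError via the derived From impl
    Ok(())
}
```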
- #[snafu(display("Remote behaved in a non-compliant way"))] - NoncompliantNode { - #[snafu(source(from(NoncompliantNodeCases, Box::new)))] - source: Box, + #[snafu(transparent)] + AtBlobHeaderNext { + source: AtBlobHeaderNextError, }, - - /// Network or IO operation failed. - #[snafu(display("A network or IO operation failed"))] - Io { - #[snafu(source(from(IoCases, Box::new)))] - source: Box, + #[snafu(transparent)] + Decode { + source: DecodeError, }, - /// Our download request is invalid. - #[snafu(display("Our download request is invalid"))] - BadRequest { - #[snafu(source(from(BadRequestCases, Box::new)))] - source: Box, + #[snafu(transparent)] + IrpcSend { + source: irpc::channel::SendError, + }, + #[snafu(transparent)] + AtClosingNext { + source: AtClosingNextError, }, - /// Operation failed on the local node. - #[snafu(display("Operation failed on the local node"))] LocalFailure { - #[snafu(source(from(LocalFailureCases, Box::new)))] - source: Box, + source: anyhow::Error, + }, + BadRequest { + source: anyhow::Error, }, } -pub type GetResult = std::result::Result; - -impl From for GetError { - fn from(value: irpc::channel::SendError) -> Self { - LocalFailureSnafu.into_error(value.into()) - } -} - -impl From> for GetError { - fn from(value: tokio::sync::mpsc::error::SendError) -> Self { - LocalFailureSnafu.into_error(value.into()) - } -} - -impl From for GetError { - fn from(value: endpoint::ConnectionError) -> Self { - // explicit match just to be sure we are taking everything into account - use endpoint::ConnectionError; - match value { - e @ ConnectionError::VersionMismatch => { - // > The peer doesn't implement any supported version - // unsupported version is likely a long time error, so this peer is not usable - NoncompliantNodeSnafu.into_error(e.into()) - } - e @ ConnectionError::TransportError(_) => { - // > The peer violated the QUIC specification as understood by this implementation - // bad peer we don't want to keep around - NoncompliantNodeSnafu.into_error(e.into()) - } - e @ ConnectionError::ConnectionClosed(_) => { - // > The peer's QUIC stack aborted the connection automatically - // peer might be disconnecting or otherwise unavailable, drop it - IoSnafu.into_error(e.into()) - } - e @ ConnectionError::ApplicationClosed(_) => { - // > The peer closed the connection - // peer might be disconnecting or otherwise unavailable, drop it - IoSnafu.into_error(e.into()) - } - e @ ConnectionError::Reset => { - // > The peer is unable to continue processing this connection, usually due to having restarted - RemoteResetSnafu.into_error(e.into()) - } - e @ ConnectionError::TimedOut => { - // > Communication with the peer has lapsed for longer than the negotiated idle timeout - IoSnafu.into_error(e.into()) - } - e @ ConnectionError::LocallyClosed => { - // > The local application closed the connection - // TODO(@divma): don't see how this is reachable but let's just not use the peer - IoSnafu.into_error(e.into()) - } - e @ ConnectionError::CidsExhausted => { - // > The connection could not be created because not enough of the CID space - // > is available - IoSnafu.into_error(e.into()) - } - } - } -} - -impl From for GetError { - fn from(value: endpoint::ReadError) -> Self { - use endpoint::ReadError; - match value { - e @ ReadError::Reset(_) => RemoteResetSnafu.into_error(e.into()), - ReadError::ConnectionLost(conn_error) => conn_error.into(), - ReadError::ClosedStream - | ReadError::IllegalOrderedRead - | ReadError::ZeroRttRejected => { - // all these errors indicate the peer is 
not usable at this moment - IoSnafu.into_error(value.into()) - } +impl GetError { + pub fn iroh_error_code(&self) -> Option { + if let Some(ReadError::Reset(code)) = self + .remote_read() + .and_then(|source| source.get_ref()) + .and_then(|e| e.downcast_ref::()) + { + Some(*code) + } else if let Some(WriteError::Stopped(code)) = self + .remote_write() + .and_then(|source| source.get_ref()) + .and_then(|e| e.downcast_ref::()) + { + Some(*code) + } else if let Some(ConnectionError::ApplicationClosed(ac)) = self + .open() + .and_then(|source| source.get_ref()) + .and_then(|e| e.downcast_ref::()) + { + Some(ac.error_code) + } else { + None } } -} -impl From for GetError { - fn from(value: ClosedStream) -> Self { - IoSnafu.into_error(value.into()) - } -} -impl From for GetError { - fn from(value: quinn::WriteError) -> Self { - use quinn::WriteError; - match value { - e @ WriteError::Stopped(_) => RemoteResetSnafu.into_error(e.into()), - WriteError::ConnectionLost(conn_error) => conn_error.into(), - WriteError::ClosedStream | WriteError::ZeroRttRejected => { - // all these errors indicate the peer is not usable at this moment - IoSnafu.into_error(value.into()) - } + pub fn remote_write(&self) -> Option<&io::Error> { + match self { + Self::ConnectedNext { + source: ConnectedNextError::Write { source, .. }, + .. + } => Some(&source), + _ => None, } } -} -impl From for GetError { - fn from(value: crate::get::fsm::ConnectedNextError) -> Self { - use crate::get::fsm::ConnectedNextError::*; - match value { - e @ PostcardSer { .. } => { - // serialization errors indicate something wrong with the request itself - BadRequestSnafu.into_error(e.into()) - } - e @ RequestTooBig { .. } => { - // request will never be sent, drop it - BadRequestSnafu.into_error(e.into()) - } - Write { source, .. } => source.into(), - Closed { source, .. } => source.into(), - e @ Io { .. } => { - // io errors are likely recoverable - IoSnafu.into_error(e.into()) - } + pub fn open(&self) -> Option<&io::Error> { + match self { + Self::InitialNext { + source: InitialNextError::Open { source, .. }, + .. + } => Some(&source), + _ => None, } } -} -impl From for GetError { - fn from(value: crate::get::fsm::AtBlobHeaderNextError) -> Self { - use crate::get::fsm::AtBlobHeaderNextError::*; - match value { - e @ NotFound { .. } => { - // > This indicates that the provider does not have the requested data. - // peer might have the data later, simply retry it - NotFoundSnafu.into_error(e.into()) - } - EndpointRead { source, .. } => source.into(), - e @ Io { .. } => { - // io errors are likely recoverable - IoSnafu.into_error(e.into()) - } + pub fn remote_read(&self) -> Option<&io::Error> { + match self { + Self::AtBlobHeaderNext { + source: AtBlobHeaderNextError::Read { source, .. }, + .. + } => Some(&source), + Self::Decode { + source: DecodeError::Read { source, .. }, + .. + } => Some(&source), + Self::AtClosingNext { + source: AtClosingNextError::Read { source, .. }, + .. + } => Some(&source), + _ => None, } } -} - -impl From for GetError { - fn from(value: crate::get::fsm::DecodeError) -> Self { - use crate::get::fsm::DecodeError::*; - match value { - e @ ChunkNotFound { .. } => NotFoundSnafu.into_error(e.into()), - e @ ParentNotFound { .. } => NotFoundSnafu.into_error(e.into()), - e @ LeafNotFound { .. } => NotFoundSnafu.into_error(e.into()), - e @ ParentHashMismatch { .. } => { - // TODO(@divma): did the peer sent wrong data? is it corrupted? did we sent a wrong - // request? 
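With the quinn-specific `From` impls gone, transport details are recovered through the accessors above: each one peels the `io::Error` out of the matching FSM stage, and `iroh_error_code` downcasts it back to the QUIC error to extract the application error code. A usage sketch, assuming the stripped return type is `Option<VarInt>` as the body implies (the logging is illustrative):

```rust
use iroh::endpoint::VarInt;

// Sketch: surface the application error code a peer signalled, if any.
fn log_peer_code(err: &GetError) {
    if let Some(code) = err.iroh_error_code() {
        // VarInt is quinn's variable-length integer; into_inner yields the u64
        tracing::warn!("peer closed the request with code {}", code.into_inner());
    }
}
```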
- NoncompliantNodeSnafu.into_error(e.into()) - } - e @ LeafHashMismatch { .. } => { - // TODO(@divma): did the peer sent wrong data? is it corrupted? did we sent a wrong - // request? - NoncompliantNodeSnafu.into_error(e.into()) - } - Read { source, .. } => source.into(), - DecodeIo { source, .. } => source.into(), + pub fn local_write(&self) -> Option<&io::Error> { + match self { + Self::Decode { + source: DecodeError::Write { source, .. }, + .. + } => Some(&source), + _ => None, } } } -impl From for GetError { - fn from(value: std::io::Error) -> Self { - // generally consider io errors recoverable - // we might want to revisit this at some point - IoSnafu.into_error(value.into()) - } -} +pub type GetResult = std::result::Result; diff --git a/src/get/request.rs b/src/get/request.rs index 98563057e..2da557640 100644 --- a/src/get/request.rs +++ b/src/get/request.rs @@ -25,7 +25,7 @@ use tokio::sync::mpsc; use super::{fsm, GetError, GetResult, Stats}; use crate::{ - get::error::{BadRequestSnafu, LocalFailureSnafu}, + get::get_error::{BadRequestSnafu, LocalFailureSnafu}, hashseq::HashSeq, protocol::{ChunkRangesExt, ChunkRangesSeq, GetRequest}, Hash, HashAndFormat, From 2f9ebd54a375e4d2b673d0f288133b82fbe0408e Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Thu, 4 Sep 2025 10:43:02 +0200 Subject: [PATCH 23/35] silence some of the tests --- src/api/downloader.rs | 12 +--- src/api/remote.rs | 1 - src/get.rs | 87 +++++++++++++++++++++-------- src/store/fs/util/entity_manager.rs | 4 -- src/tests.rs | 9 +-- 5 files changed, 71 insertions(+), 42 deletions(-) diff --git a/src/api/downloader.rs b/src/api/downloader.rs index 3555eca9c..35a277f5f 100644 --- a/src/api/downloader.rs +++ b/src/api/downloader.rs @@ -563,9 +563,7 @@ mod tests { .download(request, Shuffled::new(vec![node1_id, node2_id])) .stream() .await?; - while let Some(item) = progress.next().await { - println!("Got item: {item:?}"); - } + while let Some(_) = progress.next().await {} assert_eq!(store3.get_bytes(tt1.hash).await?.deref(), b"hello world"); assert_eq!(store3.get_bytes(tt2.hash).await?.deref(), b"hello world 2"); Ok(()) @@ -608,9 +606,7 @@ mod tests { )) .stream() .await?; - while let Some(item) = progress.next().await { - println!("Got item: {item:?}"); - } + while let Some(_) = progress.next().await {} } if false { let conn = r3.endpoint().connect(node1_addr, crate::ALPN).await?; @@ -672,9 +668,7 @@ mod tests { )) .stream() .await?; - while let Some(item) = progress.next().await { - println!("Got item: {item:?}"); - } + while let Some(_) = progress.next().await {} Ok(()) } } diff --git a/src/api/remote.rs b/src/api/remote.rs index 0d729081e..b357133b3 100644 --- a/src/api/remote.rs +++ b/src/api/remote.rs @@ -759,7 +759,6 @@ impl Remote { Err(at_closing) => break at_closing, }; let offset = at_start_child.offset(); - println!("offset {offset}"); let Some(hash) = hash_seq.get(offset as usize) else { break at_start_child.finish(); }; diff --git a/src/get.rs b/src/get.rs index 6032857e4..d7c364392 100644 --- a/src/get.rs +++ b/src/get.rs @@ -18,18 +18,20 @@ //! 
[iroh]: https://docs.rs/iroh use std::{ fmt::{self, Debug}, + io, time::{Duration, Instant}, }; use anyhow::Result; use bao_tree::{io::fsm::BaoContentItem, ChunkNum}; use fsm::RequestCounters; -use iroh::endpoint::{RecvStream, SendStream}; -use iroh_io::{TokioStreamReader, TokioStreamWriter}; +use iroh_io::{AsyncStreamReader, AsyncStreamWriter}; use n0_snafu::SpanTrace; use nested_enum_utils::common_fields; +use quinn::ReadExactError; use serde::{Deserialize, Serialize}; use snafu::{Backtrace, IntoError, ResultExt, Snafu}; +use tokio::io::AsyncWriteExt; use tracing::{debug, error}; use crate::{protocol::ChunkRangesSeq, store::IROH_BLOCK_SIZE, Hash}; @@ -39,8 +41,49 @@ pub mod request; pub(crate) use error::get_error; pub use error::{GetError, GetResult}; -type DefaultReader = TokioStreamReader; -type DefaultWriter = TokioStreamWriter; +pub struct IrohStreamWriter(iroh::endpoint::SendStream); + +impl AsyncStreamWriter for IrohStreamWriter { + async fn write(&mut self, data: &[u8]) -> io::Result<()> { + Ok(self.0.write_all(data).await?) + } + + async fn write_bytes(&mut self, data: bytes::Bytes) -> io::Result<()> { + Ok(self.0.write_chunk(data).await?) + } + + async fn sync(&mut self) -> io::Result<()> { + Ok(self.0.flush().await?) + } +} + +pub struct IrohStreamReader(iroh::endpoint::RecvStream); + +impl AsyncStreamReader for IrohStreamReader { + async fn read(&mut self) -> io::Result<[u8; N]> { + let mut buf = [0u8; N]; + match self.0.read_exact(&mut buf).await { + Ok(()) => Ok(buf), + Err(ReadExactError::ReadError(e)) => Err(e.into()), + Err(ReadExactError::FinishedEarly(_)) => Err(io::ErrorKind::UnexpectedEof.into()), + } + } + + async fn read_bytes(&mut self, len: usize) -> io::Result { + let mut buf = vec![0u8; len]; + match self.0.read_exact(&mut buf).await { + Ok(()) => Ok(buf.into()), + Err(ReadExactError::ReadError(e)) => Err(e.into()), + Err(ReadExactError::FinishedEarly(n)) => { + buf.truncate(n); + Ok(buf.into()) + } + } + } +} + +type DefaultReader = IrohStreamReader; +type DefaultWriter = IrohStreamWriter; /// Stats about the transfer. 
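Because `DefaultReader`/`DefaultWriter` are only defaults, the same machinery can run over an in-process transport, which is presumably the point of handling `ReadExactError` here instead of quinn-specific read errors. A sketch using the tokio adapters that iroh-io already ships; the test wiring is assumed, not part of this patch:

```rust
use iroh_io::{TokioStreamReader, TokioStreamWriter};
use tokio::io::{DuplexStream, ReadHalf, WriteHalf};

type TestReader = TokioStreamReader<ReadHalf<DuplexStream>>;
type TestWriter = TokioStreamWriter<WriteHalf<DuplexStream>>;

// Sketch: one end of an in-memory pipe that satisfies the same trait bounds
// as IrohStreamReader/IrohStreamWriter; hand `server` to the other side.
fn in_memory_stream() -> (TestReader, TestWriter, DuplexStream) {
    let (client, server) = tokio::io::duplex(64 * 1024);
    let (rx, tx) = tokio::io::split(client);
    (TokioStreamReader::new(rx), TokioStreamWriter(tx), server)
}
```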
#[derive( @@ -96,7 +139,7 @@ pub mod fsm { }; use derive_more::From; use iroh::endpoint::Connection; - use iroh_io::{AsyncSliceWriter, AsyncStreamReader, AsyncStreamWriter, TokioStreamReader}; + use iroh_io::{AsyncSliceWriter, AsyncStreamReader, AsyncStreamWriter}; use super::*; use crate::{ @@ -134,8 +177,8 @@ pub mod fsm { .open_bi() .await .map_err(|e| OpenSnafu.into_error(e.into()))?; - let reader = TokioStreamReader::new(reader); - let mut writer = TokioStreamWriter(writer); + let reader = IrohStreamReader(reader); + let mut writer = IrohStreamWriter(writer); let request = Request::GetMany(request); let request_bytes = postcard::to_stdvec(&request) .map_err(|source| BadRequestSnafu.into_error(source.into()))?; @@ -227,8 +270,8 @@ pub mod fsm { .open_bi() .await .map_err(|e| OpenSnafu.into_error(e.into()))?; - let reader = TokioStreamReader::new(reader); - let writer = TokioStreamWriter(writer); + let reader = IrohStreamReader(reader); + let writer = IrohStreamWriter(writer); Ok(AtConnected { start, reader, @@ -375,7 +418,7 @@ pub mod fsm { /// State of the get response when we start reading a collection #[derive(Debug)] - pub struct AtStartRoot> { + pub struct AtStartRoot { ranges: ChunkRanges, reader: R, misc: Box, @@ -384,7 +427,7 @@ pub mod fsm { /// State of the get response when we start reading a child #[derive(Debug)] - pub struct AtStartChild> { + pub struct AtStartChild { ranges: ChunkRanges, reader: R, misc: Box, @@ -459,7 +502,7 @@ pub mod fsm { /// State before reading a size header #[derive(Debug)] - pub struct AtBlobHeader> { + pub struct AtBlobHeader { ranges: ChunkRanges, reader: R, misc: Box, @@ -587,7 +630,7 @@ pub mod fsm { /// State while we are reading content #[derive(Debug)] - pub struct AtBlobContent> { + pub struct AtBlobContent { stream: ResponseDecoder, misc: Box, } @@ -683,17 +726,17 @@ pub mod fsm { impl From for DecodeError { fn from(value: bao_tree::io::DecodeError) -> Self { match value { - bao_tree::io::DecodeError::ParentNotFound(x) => { - decode_error::ParentNotFoundSnafu { node: x }.build() + bao_tree::io::DecodeError::ParentNotFound(node) => { + decode_error::ParentNotFoundSnafu { node }.build() } - bao_tree::io::DecodeError::LeafNotFound(x) => { - decode_error::LeafNotFoundSnafu { num: x }.build() + bao_tree::io::DecodeError::LeafNotFound(num) => { + decode_error::LeafNotFoundSnafu { num }.build() } bao_tree::io::DecodeError::ParentHashMismatch(node) => { decode_error::ParentHashMismatchSnafu { node }.build() } - bao_tree::io::DecodeError::LeafHashMismatch(chunk) => { - decode_error::LeafHashMismatchSnafu { num: chunk }.build() + bao_tree::io::DecodeError::LeafHashMismatch(num) => { + decode_error::LeafHashMismatchSnafu { num }.build() } bao_tree::io::DecodeError::Io(cause) => decode_error::ReadSnafu.into_error(cause), } @@ -876,14 +919,14 @@ pub mod fsm { /// State after we have read all the content for a blob #[derive(Debug)] - pub struct AtEndBlob> { + pub struct AtEndBlob { stream: R, misc: Box, } /// The next state after the end of a blob #[derive(Debug, From)] - pub enum EndBlobNext> { + pub enum EndBlobNext { /// Response is expected to have more children MoreChildren(AtStartChild), /// No more children expected @@ -909,7 +952,7 @@ pub mod fsm { /// State when finishing the get response #[derive(Debug)] - pub struct AtClosing> { + pub struct AtClosing { misc: Box, reader: R, check_extra_data: bool, diff --git a/src/store/fs/util/entity_manager.rs b/src/store/fs/util/entity_manager.rs index 91a737d76..b0b2898ea 100644 --- 
a/src/store/fs/util/entity_manager.rs +++ b/src/store/fs/util/entity_manager.rs @@ -1186,10 +1186,6 @@ mod tests { .spawn(id, move |arg| async move { match arg { SpawnArg::Active(state) => { - println!( - "Adding value {} to entity actor with id {:?}", - value, state.id - ); state .with_value(|v| *v = v.wrapping_add(value)) .await diff --git a/src/tests.rs b/src/tests.rs index 0ef0c027c..0b3b2fde2 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -556,6 +556,7 @@ async fn two_nodes_hash_seq( } #[tokio::test] + async fn two_nodes_hash_seq_fs() -> TestResult<()> { tracing_subscriber::fmt::try_init().ok(); let (_testdir, (r1, store1, _), (r2, store2, _)) = two_node_test_setup_fs().await?; @@ -578,9 +579,7 @@ async fn two_nodes_hash_seq_progress() -> TestResult<()> { let root = add_test_hash_seq(&store1, sizes).await?; let conn = r2.endpoint().connect(addr1, crate::ALPN).await?; let mut stream = store2.remote().fetch(conn, root).stream(); - while let Some(item) = stream.next().await { - println!("{item:?}"); - } + while let Some(_) = stream.next().await {} check_presence(&store2, &sizes).await?; Ok(()) } @@ -648,9 +647,7 @@ async fn node_serve_blobs() -> TestResult<()> { let expected = test_data(size); let hash = Hash::new(&expected); let mut stream = get::request::get_blob(conn.clone(), hash); - while let Some(item) = stream.next().await { - println!("{item:?}"); - } + while let Some(_) = stream.next().await {} let actual = get::request::get_blob(conn.clone(), hash).await?; assert_eq!(actual.len(), expected.len(), "size: {size}"); } From d764dc0d48aa50948cbf07e0eed04da89ad8c592 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Thu, 4 Sep 2025 11:16:36 +0200 Subject: [PATCH 24/35] Genericize provider side a bit --- src/api/blobs.rs | 42 +++++++++++------------------------------- src/api/remote.rs | 16 ++++++---------- src/get.rs | 4 ++-- src/provider.rs | 31 ++++++++++++++++++------------- 4 files changed, 37 insertions(+), 56 deletions(-) diff --git a/src/api/blobs.rs b/src/api/blobs.rs index 1822be5b2..830dd2042 100644 --- a/src/api/blobs.rs +++ b/src/api/blobs.rs @@ -23,14 +23,12 @@ use bao_tree::{ }; use bytes::Bytes; use genawaiter::sync::Gen; -use iroh_io::{AsyncStreamReader, TokioStreamReader}; +use iroh_io::{AsyncStreamReader, AsyncStreamWriter}; use irpc::channel::{mpsc, oneshot}; use n0_future::{future, stream, Stream, StreamExt}; -use quinn::SendStream; use range_collections::{range_set::RangeSetRange, RangeSet2}; use ref_cast::RefCast; use serde::{Deserialize, Serialize}; -use tokio::io::AsyncWriteExt; use tracing::trace; mod reader; pub use reader::BlobReader; @@ -431,7 +429,7 @@ impl Blobs { } #[cfg_attr(feature = "hide-proto-docs", doc(hidden))] - async fn import_bao_reader( + pub async fn import_bao_reader( &self, hash: Hash, ranges: ChunkRanges, @@ -468,18 +466,6 @@ impl Blobs { Ok(reader?) 
} - #[cfg_attr(feature = "hide-proto-docs", doc(hidden))] - pub async fn import_bao_quinn( - &self, - hash: Hash, - ranges: ChunkRanges, - stream: &mut iroh::endpoint::RecvStream, - ) -> RequestResult<()> { - let reader = TokioStreamReader::new(stream); - self.import_bao_reader(hash, ranges, reader).await?; - Ok(()) - } - #[cfg_attr(feature = "hide-proto-docs", doc(hidden))] pub async fn import_bao_bytes( &self, @@ -1058,24 +1044,21 @@ impl ExportBaoProgress { Ok(data) } - pub async fn write_quinn(self, target: &mut quinn::SendStream) -> super::ExportBaoResult<()> { + pub async fn write(self, target: &mut W) -> super::ExportBaoResult<()> { let mut rx = self.inner.await?; while let Some(item) = rx.recv().await? { match item { EncodedItem::Size(size) => { - target.write_u64_le(size).await?; + target.write(&size.to_le_bytes()).await?; } EncodedItem::Parent(parent) => { let mut data = vec![0u8; 64]; data[..32].copy_from_slice(parent.pair.0.as_bytes()); data[32..].copy_from_slice(parent.pair.1.as_bytes()); - target.write_all(&data).await.map_err(io::Error::from)?; + target.write(&data).await?; } EncodedItem::Leaf(leaf) => { - target - .write_chunk(leaf.data) - .await - .map_err(io::Error::from)?; + target.write_bytes(leaf.data).await?; } EncodedItem::Done => break, EncodedItem::Error(cause) => return Err(cause.into()), @@ -1085,9 +1068,9 @@ impl ExportBaoProgress { } /// Write quinn variant that also feeds a progress writer. - pub(crate) async fn write_quinn_with_progress( + pub(crate) async fn write_with_progress( self, - writer: &mut SendStream, + writer: &mut W, progress: &mut impl WriteProgress, hash: &Hash, index: u64, @@ -1097,22 +1080,19 @@ impl ExportBaoProgress { match item { EncodedItem::Size(size) => { progress.send_transfer_started(index, hash, size).await; - writer.write_u64_le(size).await?; + writer.write(&size.to_le_bytes()).await?; progress.log_other_write(8); } EncodedItem::Parent(parent) => { let mut data = vec![0u8; 64]; data[..32].copy_from_slice(parent.pair.0.as_bytes()); data[32..].copy_from_slice(parent.pair.1.as_bytes()); - writer.write_all(&data).await.map_err(io::Error::from)?; + writer.write(&data).await?; progress.log_other_write(64); } EncodedItem::Leaf(leaf) => { let len = leaf.data.len(); - writer - .write_chunk(leaf.data) - .await - .map_err(io::Error::from)?; + writer.write_bytes(leaf.data).await?; progress .notify_payload_write(index, leaf.offset, len) .await?; diff --git a/src/api/remote.rs b/src/api/remote.rs index b357133b3..ec9fc7874 100644 --- a/src/api/remote.rs +++ b/src/api/remote.rs @@ -16,7 +16,7 @@ use crate::{ get::{ fsm::DecodeError, get_error::{BadRequestSnafu, LocalFailureSnafu}, - GetError, GetResult, Stats, + GetError, GetResult, IrohStreamWriter, Stats, }, protocol::{ GetManyRequest, ObserveItem, ObserveRequest, PushRequest, Request, RequestType, @@ -594,15 +594,16 @@ impl Remote { let mut request_ranges = request.ranges.iter_infinite(); let root = request.hash; let root_ranges = request_ranges.next().expect("infinite iterator"); + let mut send = IrohStreamWriter(send); if !root_ranges.is_empty() { self.store() .export_bao(root, root_ranges.clone()) - .write_quinn_with_progress(&mut send, &mut context, &root, 0) + .write_with_progress(&mut send, &mut context, &root, 0) .await?; } if request.ranges.is_blob() { // we are done - send.finish()?; + send.0.finish()?; return Ok(Default::default()); } let hash_seq = self.store().get_bytes(root).await?; @@ -613,16 +614,11 @@ impl Remote { if !child_ranges.is_empty() { self.store() 
.export_bao(child_hash, child_ranges.clone()) - .write_quinn_with_progress( - &mut send, - &mut context, - &child_hash, - (child + 1) as u64, - ) + .write_with_progress(&mut send, &mut context, &child_hash, (child + 1) as u64) .await?; } } - send.finish()?; + send.0.finish()?; Ok(Default::default()) } diff --git a/src/get.rs b/src/get.rs index d7c364392..f93eeabbb 100644 --- a/src/get.rs +++ b/src/get.rs @@ -41,7 +41,7 @@ pub mod request; pub(crate) use error::get_error; pub use error::{GetError, GetResult}; -pub struct IrohStreamWriter(iroh::endpoint::SendStream); +pub struct IrohStreamWriter(pub iroh::endpoint::SendStream); impl AsyncStreamWriter for IrohStreamWriter { async fn write(&mut self, data: &[u8]) -> io::Result<()> { @@ -57,7 +57,7 @@ impl AsyncStreamWriter for IrohStreamWriter { } } -pub struct IrohStreamReader(iroh::endpoint::RecvStream); +pub struct IrohStreamReader(pub iroh::endpoint::RecvStream); impl AsyncStreamReader for IrohStreamReader { async fn read(&mut self) -> io::Result<[u8; N]> { diff --git a/src/provider.rs b/src/provider.rs index 0134169c6..0d9da3d5d 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -12,6 +12,7 @@ use std::{ use anyhow::{Context, Result}; use bao_tree::ChunkRanges; use iroh::endpoint::{self, RecvStream, SendStream}; +use iroh_io::{AsyncStreamReader, AsyncStreamWriter}; use n0_future::StreamExt; use quinn::{ClosedStream, ConnectionError, ReadToEndError}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; @@ -23,6 +24,7 @@ use crate::{ blobs::{Bitfield, WriteProgress}, ExportBaoResult, Store, }, + get::{IrohStreamReader, IrohStreamWriter}, hashseq::HashSeq, protocol::{GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, Request}, provider::events::{ClientConnected, ClientResult, ConnectionClosed, RequestTracker}, @@ -31,6 +33,9 @@ use crate::{ pub mod events; use events::EventSender; +type DefaultWriter = IrohStreamWriter; +type DefaultReader = IrohStreamReader; + /// Statistics about a successful or failed transfer. #[derive(Debug, Serialize, Deserialize)] pub struct TransferStats { @@ -106,7 +111,7 @@ impl StreamPair { return Err(e); }; Ok(ProgressWriter::new( - self.writer, + IrohStreamWriter(self.writer), WriterContext { t0: self.t0, other_bytes_read: self.other_bytes_read, @@ -130,7 +135,7 @@ impl StreamPair { return Err(e); }; Ok(ProgressReader { - inner: self.reader, + inner: IrohStreamReader(self.reader), context: ReaderContext { t0: self.t0, other_bytes_read: self.other_bytes_read, @@ -282,14 +287,14 @@ impl WriteProgress for WriterContext { /// Wrapper for a [`quinn::SendStream`] with additional per request information. 
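Reduced to its skeleton, the push path rewritten above is: wrap the QUIC send stream once, stream the root blob, then each child of the hash sequence, and finally reach through `.0` to finish the underlying stream. A hedged fragment with progress reporting and per-child range selection elided:

```rust
// Sketch only: `store`, `send`, `root`, `root_ranges` and `hash_seq` are
// assumed to be in scope as in the hunk above.
let mut send = IrohStreamWriter(send);
store.export_bao(root, root_ranges).write(&mut send).await?;
for hash in hash_seq.iter() {
    store.export_bao(hash, ChunkRanges::all()).write(&mut send).await?;
}
send.0.finish()?; // finish() lives on the wrapped iroh SendStream
```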
#[derive(Debug)] -pub struct ProgressWriter { +pub struct ProgressWriter { /// The quinn::SendStream to write to - pub inner: SendStream, + pub inner: W, pub(crate) context: WriterContext, } -impl ProgressWriter { - fn new(inner: SendStream, context: WriterContext) -> Self { +impl ProgressWriter { + fn new(inner: W, context: WriterContext) -> Self { Self { inner, context } } @@ -465,7 +470,7 @@ pub async fn handle_push( if !root_ranges.is_empty() { // todo: send progress from import_bao_quinn or rename to import_bao_quinn_with_progress store - .import_bao_quinn(hash, root_ranges.clone(), &mut reader.inner) + .import_bao_reader(hash, root_ranges.clone(), &mut reader.inner) .await?; } if request.ranges.is_blob() { @@ -480,7 +485,7 @@ pub async fn handle_push( continue; } store - .import_bao_quinn(child_hash, child_ranges.clone(), &mut reader.inner) + .import_bao_reader(child_hash, child_ranges.clone(), &mut reader.inner) .await?; } Ok(()) @@ -496,7 +501,7 @@ pub(crate) async fn send_blob( ) -> ExportBaoResult<()> { store .export_bao(hash, ranges) - .write_quinn_with_progress(&mut writer.inner, &mut writer.context, &hash, index) + .write_with_progress(&mut writer.inner, &mut writer.context, &hash, index) .await } @@ -527,7 +532,7 @@ pub async fn handle_observe( send_observe_item(writer, &diff).await?; old = new; } - _ = writer.inner.stopped() => { + _ = writer.inner.0.stopped() => { debug!("observer closed"); break; } @@ -539,13 +544,13 @@ pub async fn handle_observe( async fn send_observe_item(writer: &mut ProgressWriter, item: &Bitfield) -> Result<()> { use irpc::util::AsyncWriteVarintExt; let item = ObserveItem::from(item); - let len = writer.inner.write_length_prefixed(item).await?; + let len = writer.inner.0.write_length_prefixed(item).await?; writer.context.log_other_write(len); Ok(()) } -pub struct ProgressReader { - inner: RecvStream, +pub struct ProgressReader { + inner: R, context: ReaderContext, } From 4e8387a071804aa6d87d443cf247c4dbca81fe97 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Thu, 4 Sep 2025 16:04:44 +0200 Subject: [PATCH 25/35] Refactor error and make get and provide side generic --- src/api.rs | 6 +- src/api/downloader.rs | 6 +- src/api/remote.rs | 10 +- src/get.rs | 4 +- src/get/error.rs | 12 +- src/get/request.rs | 8 +- src/protocol.rs | 5 +- src/provider.rs | 547 +++++++++++++++++++++++++++++------------ src/provider/events.rs | 10 +- src/tests.rs | 4 +- 10 files changed, 428 insertions(+), 184 deletions(-) diff --git a/src/api.rs b/src/api.rs index 117c59e25..dc38498d3 100644 --- a/src/api.rs +++ b/src/api.rs @@ -98,7 +98,7 @@ pub enum ExportBaoError { #[snafu(display("encode error: {source}"))] ExportBaoInner { source: bao_tree::io::EncodeError }, #[snafu(display("client error: {source}"))] - ClientError { source: ProgressError }, + Progress { source: ProgressError }, } impl From for Error { @@ -109,7 +109,7 @@ impl From for Error { ExportBaoError::Request { source, .. } => Self::Io(source.into()), ExportBaoError::ExportBaoIo { source, .. } => Self::Io(source), ExportBaoError::ExportBaoInner { source, .. } => Self::Io(source.into()), - ExportBaoError::ClientError { source, .. } => Self::Io(source.into()), + ExportBaoError::Progress { source, .. 
} => Self::Io(source.into()), } } } @@ -157,7 +157,7 @@ impl From for ExportBaoError { impl From for ExportBaoError { fn from(value: ProgressError) -> Self { - ClientSnafu.into_error(value) + ProgressSnafu.into_error(value) } } diff --git a/src/api/downloader.rs b/src/api/downloader.rs index 35a277f5f..8ac188000 100644 --- a/src/api/downloader.rs +++ b/src/api/downloader.rs @@ -563,7 +563,7 @@ mod tests { .download(request, Shuffled::new(vec![node1_id, node2_id])) .stream() .await?; - while let Some(_) = progress.next().await {} + while progress.next().await.is_some() {} assert_eq!(store3.get_bytes(tt1.hash).await?.deref(), b"hello world"); assert_eq!(store3.get_bytes(tt2.hash).await?.deref(), b"hello world 2"); Ok(()) @@ -606,7 +606,7 @@ mod tests { )) .stream() .await?; - while let Some(_) = progress.next().await {} + while progress.next().await.is_some() {} } if false { let conn = r3.endpoint().connect(node1_addr, crate::ALPN).await?; @@ -668,7 +668,7 @@ mod tests { )) .stream() .await?; - while let Some(_) = progress.next().await {} + while progress.next().await.is_some() {} Ok(()) } } diff --git a/src/api/remote.rs b/src/api/remote.rs index ec9fc7874..8877c6297 100644 --- a/src/api/remote.rs +++ b/src/api/remote.rs @@ -99,8 +99,7 @@ impl GetProgress { pub async fn complete(self) -> GetResult { just_result(self.stream()).await.unwrap_or_else(|| { - Err(LocalFailureSnafu - .into_error(anyhow::anyhow!("stream closed without result").into())) + Err(LocalFailureSnafu.into_error(anyhow::anyhow!("stream closed without result"))) }) } } @@ -512,7 +511,7 @@ impl Remote { let local = self .local(content) .await - .map_err(|e: anyhow::Error| LocalFailureSnafu.into_error(e.into()))?; + .map_err(|e: anyhow::Error| LocalFailureSnafu.into_error(e))?; if local.is_complete() { return Ok(Default::default()); } @@ -520,7 +519,7 @@ impl Remote { let conn = conn .connection() .await - .map_err(|e| LocalFailureSnafu.into_error(e.into()))?; + .map_err(|e| LocalFailureSnafu.into_error(e))?; let stats = self.execute_get_sink(&conn, request, progress).await?; Ok(stats) } @@ -914,8 +913,7 @@ async fn get_blob_ranges_impl( }; let complete = async move { handle.rx.await.map_err(|e| { - LocalFailureSnafu - .into_error(anyhow::anyhow!("error reading from import stream: {e}").into()) + LocalFailureSnafu.into_error(anyhow::anyhow!("error reading from import stream: {e}")) }) }; let (_, end) = tokio::try_join!(complete, write)?; diff --git a/src/get.rs b/src/get.rs index f93eeabbb..bbe8ad26b 100644 --- a/src/get.rs +++ b/src/get.rs @@ -53,7 +53,7 @@ impl AsyncStreamWriter for IrohStreamWriter { } async fn sync(&mut self) -> io::Result<()> { - Ok(self.0.flush().await?) + self.0.flush().await } } @@ -716,7 +716,7 @@ pub mod fsm { DecodeError::LeafNotFound { .. } => { io::Error::new(io::ErrorKind::UnexpectedEof, cause) } - DecodeError::Read { source, .. } => source.into(), + DecodeError::Read { source, .. } => source, DecodeError::Write { source, .. } => source, _ => io::Error::other(cause), } diff --git a/src/get/error.rs b/src/get/error.rs index c949593e7..5cc44e35b 100644 --- a/src/get/error.rs +++ b/src/get/error.rs @@ -82,7 +82,7 @@ impl GetError { Self::ConnectedNext { source: ConnectedNextError::Write { source, .. }, .. - } => Some(&source), + } => Some(source), _ => None, } } @@ -92,7 +92,7 @@ impl GetError { Self::InitialNext { source: InitialNextError::Open { source, .. }, .. 
- } => Some(&source), + } => Some(source), _ => None, } } @@ -102,15 +102,15 @@ impl GetError { Self::AtBlobHeaderNext { source: AtBlobHeaderNextError::Read { source, .. }, .. - } => Some(&source), + } => Some(source), Self::Decode { source: DecodeError::Read { source, .. }, .. - } => Some(&source), + } => Some(source), Self::AtClosingNext { source: AtClosingNextError::Read { source, .. }, .. - } => Some(&source), + } => Some(source), _ => None, } } @@ -120,7 +120,7 @@ impl GetError { Self::Decode { source: DecodeError::Write { source, .. }, .. - } => Some(&source), + } => Some(source), _ => None, } } diff --git a/src/get/request.rs b/src/get/request.rs index 2da557640..c1dc034d3 100644 --- a/src/get/request.rs +++ b/src/get/request.rs @@ -58,7 +58,7 @@ impl GetBlobResult { let mut parts = Vec::new(); let stats = loop { let Some(item) = self.next().await else { - return Err(LocalFailureSnafu.into_error(anyhow::anyhow!("unexpected end").into())); + return Err(LocalFailureSnafu.into_error(anyhow::anyhow!("unexpected end"))); }; match item { GetBlobItem::Item(item) => { @@ -238,11 +238,11 @@ pub async fn get_hash_seq_and_sizes( let (at_blob_content, size) = at_start_root.next().await?; // check the size to avoid parsing a maliciously large hash seq if size > max_size { - return Err(BadRequestSnafu.into_error(anyhow::anyhow!("size too large").into())); + return Err(BadRequestSnafu.into_error(anyhow::anyhow!("size too large"))); } let (mut curr, hash_seq) = at_blob_content.concatenate_into_vec().await?; - let hash_seq = HashSeq::try_from(Bytes::from(hash_seq)) - .map_err(|e| BadRequestSnafu.into_error(e.into()))?; + let hash_seq = + HashSeq::try_from(Bytes::from(hash_seq)).map_err(|e| BadRequestSnafu.into_error(e))?; let mut sizes = Vec::with_capacity(hash_seq.len()); let closing = loop { match curr.next() { diff --git a/src/protocol.rs b/src/protocol.rs index ce10865a5..8aed6539a 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -382,7 +382,7 @@ use bao_tree::{io::round_up_to_chunks, ChunkNum}; use builder::GetRequestBuilder; use derive_more::From; use iroh::endpoint::VarInt; -use irpc::util::AsyncReadVarintExt; +use iroh_io::AsyncStreamReader; use postcard::experimental::max_size::MaxSize; use range_collections::{range_set::RangeSetEntry, RangeSet2}; use serde::{Deserialize, Serialize}; @@ -390,7 +390,6 @@ mod range_spec; pub use bao_tree::ChunkRanges; pub use range_spec::{ChunkRangesSeq, NonEmptyRequestRangeSpecIter, RangeSpec}; use snafu::{GenerateImplicitData, Snafu}; -use tokio::io::AsyncReadExt; use crate::{api::blobs::Bitfield, provider::RecvStreamExt, BlobFormat, Hash, HashAndFormat}; @@ -448,7 +447,7 @@ pub enum RequestType { } impl Request { - pub async fn read_async(reader: &mut iroh::endpoint::RecvStream) -> io::Result<(Self, usize)> { + pub async fn read_async(reader: &mut R) -> io::Result<(Self, usize)> { let request_type = reader.read_u8().await?; let request_type: RequestType = postcard::from_bytes(std::slice::from_ref(&request_type)) .map_err(|_| { diff --git a/src/provider.rs b/src/provider.rs index 0d9da3d5d..e81d783ad 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -5,29 +5,36 @@ //! handler with an [`iroh::Endpoint`](iroh::protocol::Router). 
use std::{ fmt::Debug, + future::Future, io, time::{Duration, Instant}, }; -use anyhow::{Context, Result}; +use anyhow::Result; use bao_tree::ChunkRanges; -use iroh::endpoint::{self, RecvStream, SendStream}; +use iroh::endpoint; use iroh_io::{AsyncStreamReader, AsyncStreamWriter}; use n0_future::StreamExt; -use quinn::{ClosedStream, ConnectionError, ReadToEndError}; +use quinn::{ConnectionError, VarInt}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use snafu::Snafu; use tokio::select; use tracing::{debug, debug_span, warn, Instrument}; use crate::{ api::{ blobs::{Bitfield, WriteProgress}, - ExportBaoResult, Store, + ExportBaoError, ExportBaoResult, RequestError, Store, }, get::{IrohStreamReader, IrohStreamWriter}, hashseq::HashSeq, - protocol::{GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, Request}, - provider::events::{ClientConnected, ClientResult, ConnectionClosed, RequestTracker}, + protocol::{ + GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, Request, ERR_INTERNAL, + }, + provider::events::{ + ClientConnected, ClientResult, ConnectionClosed, HasErrorCode, ProgressError, + RequestTracker, + }, Hash, }; pub mod events; @@ -56,12 +63,12 @@ pub struct TransferStats { /// A pair of [`SendStream`] and [`RecvStream`] with additional context data. #[derive(Debug)] -pub struct StreamPair { +pub struct StreamPair { t0: Instant, connection_id: u64, request_id: u64, - reader: RecvStream, - writer: SendStream, + reader: R, + writer: W, other_bytes_read: u64, events: EventSender, } @@ -69,18 +76,36 @@ pub struct StreamPair { impl StreamPair { pub async fn accept( conn: &endpoint::Connection, - events: &EventSender, + events: EventSender, ) -> Result { let (writer, reader) = conn.accept_bi().await?; - Ok(Self { + Ok(Self::new( + conn.stable_id() as u64, + reader.id().into(), + IrohStreamReader(reader), + IrohStreamWriter(writer), + events, + )) + } +} + +impl StreamPair { + pub fn new( + connection_id: u64, + request_id: u64, + reader: R, + writer: W, + events: EventSender, + ) -> Self { + Self { t0: Instant::now(), - connection_id: conn.stable_id() as u64, - request_id: reader.id().into(), + connection_id, + request_id, reader, writer, other_bytes_read: 0, - events: events.clone(), - }) + events, + } } /// Read the request. @@ -98,20 +123,14 @@ impl StreamPair { } /// We are done with reading. 
Return a ProgressWriter that contains the read stats and connection id - async fn into_writer( + pub async fn into_writer( mut self, tracker: RequestTracker, - ) -> Result { - let res = self.reader.read_to_end(0).await; - if let Err(e) = res { - tracker - .transfer_aborted(|| Box::new(self.stats())) - .await - .ok(); - return Err(e); - }; + ) -> Result, io::Error> { + self.reader.expect_eof().await?; + drop(self.reader); Ok(ProgressWriter::new( - IrohStreamWriter(self.writer), + self.writer, WriterContext { t0: self.t0, other_bytes_read: self.other_bytes_read, @@ -122,20 +141,14 @@ impl StreamPair { )) } - async fn into_reader( + pub async fn into_reader( mut self, tracker: RequestTracker, - ) -> Result { - let res = self.writer.finish(); - if let Err(e) = res { - tracker - .transfer_aborted(|| Box::new(self.stats())) - .await - .ok(); - return Err(e); - }; + ) -> Result, io::Error> { + self.writer.sync().await?; + drop(self.writer); Ok(ProgressReader { - inner: IrohStreamReader(self.reader), + inner: self.reader, context: ReaderContext { t0: self.t0, other_bytes_read: self.other_bytes_read, @@ -145,74 +158,42 @@ impl StreamPair { } pub async fn get_request( - mut self, + &self, f: impl FnOnce() -> GetRequest, - ) -> anyhow::Result { - let res = self - .events + ) -> Result { + self.events .request(f, self.connection_id, self.request_id) - .await; - match res { - Err(e) => { - self.writer.reset(e.code()).ok(); - Err(e.into()) - } - Ok(tracker) => Ok(self.into_writer(tracker).await?), - } + .await } pub async fn get_many_request( - mut self, + &self, f: impl FnOnce() -> GetManyRequest, - ) -> anyhow::Result { - let res = self - .events + ) -> Result { + self.events .request(f, self.connection_id, self.request_id) - .await; - match res { - Err(e) => { - self.writer.reset(e.code()).ok(); - Err(e.into()) - } - Ok(tracker) => Ok(self.into_writer(tracker).await?), - } + .await } pub async fn push_request( - mut self, + &self, f: impl FnOnce() -> PushRequest, - ) -> anyhow::Result { - let res = self - .events + ) -> Result { + self.events .request(f, self.connection_id, self.request_id) - .await; - match res { - Err(e) => { - self.writer.reset(e.code()).ok(); - Err(e.into()) - } - Ok(tracker) => Ok(self.into_reader(tracker).await?), - } + .await } pub async fn observe_request( - mut self, + &self, f: impl FnOnce() -> ObserveRequest, - ) -> anyhow::Result { - let res = self - .events + ) -> Result { + self.events .request(f, self.connection_id, self.request_id) - .await; - match res { - Err(e) => { - self.writer.reset(e.code()).ok(); - Err(e.into()) - } - Ok(tracker) => Ok(self.into_writer(tracker).await?), - } + .await } - fn stats(&self) -> TransferStats { + pub fn stats(&self) -> TransferStats { TransferStats { payload_bytes_sent: 0, other_bytes_sent: 0, @@ -339,7 +320,7 @@ pub async fn handle_connection( debug!("closing connection: {cause}"); return; } - while let Ok(context) = StreamPair::accept(&connection, &progress).await { + while let Ok(context) = StreamPair::accept(&connection, progress.clone()).await { let span = debug_span!("stream", stream_id = %context.request_id); let store = store.clone(); tokio::spawn(handle_stream(store, context).instrument(span)); @@ -353,58 +334,120 @@ pub async fn handle_connection( .await } -async fn handle_stream(store: Store, mut context: StreamPair) -> anyhow::Result<()> { - // 1. Decode the request. - debug!("reading request"); - let request = context.read_request().await?; +/// Describes how to handle errors for a stream. 
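+///
+/// Implementations translate an error code into the transport specific way of
+/// aborting the streams: for iroh/quinn streams this means `reset` on the
+/// send side and `stop` on the receive side, as the `IrohErrorHandler` below
+/// does.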
+pub trait ErrorHandler { + type W: AsyncStreamWriter; + type R: AsyncStreamReader; + fn stop(reader: &mut Self::R, code: VarInt) -> impl Future; + fn reset(writer: &mut Self::W, code: VarInt) -> impl Future; +} - match request { - Request::Get(request) => { - let mut writer = context.get_request(|| request.clone()).await?; - let res = handle_get(store, request, &mut writer).await; - if res.is_ok() { - writer.transfer_completed().await; - } else { - writer.transfer_aborted().await; - } +async fn handle_read_request_result( + pair: &mut StreamPair, + r: Result, +) -> Result { + match r { + Ok(x) => Ok(x), + Err(e) => { + H::reset(&mut pair.writer, e.code()).await; + Err(e) } - Request::GetMany(request) => { - let mut writer = context.get_many_request(|| request.clone()).await?; - if handle_get_many(store, request, &mut writer).await.is_ok() { - writer.transfer_completed().await; - } else { - writer.transfer_aborted().await; - } + } +} +async fn handle_write_result( + writer: &mut ProgressWriter, + r: Result, +) -> Result { + match r { + Ok(x) => { + writer.transfer_completed().await; + Ok(x) } - Request::Observe(request) => { - let mut writer = context.observe_request(|| request.clone()).await?; - if handle_observe(store, request, &mut writer).await.is_ok() { - writer.transfer_completed().await; - } else { - writer.transfer_aborted().await; - } + Err(e) => { + H::reset(&mut writer.inner, e.code()).await; + writer.transfer_aborted().await; + Err(e) } - Request::Push(request) => { - let mut reader = context.push_request(|| request.clone()).await?; - if handle_push(store, request, &mut reader).await.is_ok() { - reader.transfer_completed().await; - } else { - reader.transfer_aborted().await; - } + } +} +async fn handle_read_result( + reader: &mut ProgressReader, + r: Result, +) -> Result { + match r { + Ok(x) => { + reader.transfer_completed().await; + Ok(x) + } + Err(e) => { + H::stop(&mut reader.inner, e.code()).await; + reader.transfer_aborted().await; + Err(e) } + } +} +struct IrohErrorHandler; + +impl ErrorHandler for IrohErrorHandler { + type W = DefaultWriter; + type R = DefaultReader; + + async fn stop(reader: &mut Self::R, code: VarInt) { + reader.0.stop(code).ok(); + } + async fn reset(writer: &mut Self::W, code: VarInt) { + writer.0.reset(code).ok(); + } +} + +pub async fn handle_stream(store: Store, mut context: StreamPair) -> anyhow::Result<()> { + // 1. Decode the request. + debug!("reading request"); + let request = context.read_request().await?; + type H = IrohErrorHandler; + + match request { + Request::Get(request) => handle_get::(context, store, request).await?, + Request::GetMany(request) => handle_get_many::(context, store, request).await?, + Request::Observe(request) => handle_observe::(context, store, request).await?, + Request::Push(request) => handle_push::(context, store, request).await?, _ => {} } Ok(()) } +#[derive(Debug, Snafu)] +#[snafu(module)] +pub enum HandleGetError { + #[snafu(transparent)] + ExportBao { + source: ExportBaoError, + }, + InvalidHashSeq, + InvalidOffset, +} + +impl HasErrorCode for HandleGetError { + fn code(&self) -> VarInt { + match self { + HandleGetError::ExportBao { + source: ExportBaoError::Progress { source, .. }, + } => source.code(), + HandleGetError::InvalidHashSeq => ERR_INTERNAL, + HandleGetError::InvalidOffset => ERR_INTERNAL, + _ => ERR_INTERNAL, + } + } +} + /// Handle a single get request. /// /// Requires a database, the request, and a writer. 
-pub async fn handle_get( +async fn handle_get_impl( store: Store, request: GetRequest, - writer: &mut ProgressWriter, -) -> anyhow::Result<()> { + writer: &mut ProgressWriter, +) -> Result<(), HandleGetError> { let hash = request.hash; debug!(%hash, "get received request"); let mut hash_seq = None; @@ -421,12 +464,13 @@ pub async fn handle_get( Some(b) => b, None => { let bytes = store.get_bytes(hash).await?; - let hs = HashSeq::try_from(bytes)?; + let hs = + HashSeq::try_from(bytes).map_err(|_| HandleGetError::InvalidHashSeq)?; hash_seq = Some(hs); hash_seq.as_ref().unwrap() } }; - let o = usize::try_from(offset - 1).context("offset too large")?; + let o = usize::try_from(offset - 1).map_err(|_| HandleGetError::InvalidOffset)?; let Some(hash) = hash_seq.get(o) else { break; }; @@ -437,14 +481,44 @@ pub async fn handle_get( Ok(()) } +pub async fn handle_get( + mut pair: StreamPair, + store: Store, + request: GetRequest, +) -> anyhow::Result<()> { + let res = pair.get_request(|| request.clone()).await; + let tracker = handle_read_request_result::(&mut pair, res).await?; + let mut writer = pair.into_writer(tracker).await?; + let res = handle_get_impl(store, request, &mut writer).await; + handle_write_result::(&mut writer, res).await?; + Ok(()) +} + +#[derive(Debug, Snafu)] +pub enum HandleGetManyError { + #[snafu(transparent)] + ExportBao { source: ExportBaoError }, +} + +impl HasErrorCode for HandleGetManyError { + fn code(&self) -> VarInt { + match self { + Self::ExportBao { + source: ExportBaoError::Progress { source, .. }, + } => source.code(), + _ => ERR_INTERNAL, + } + } +} + /// Handle a single get request. /// /// Requires a database, the request, and a writer. -pub async fn handle_get_many( +async fn handle_get_many_impl( store: Store, request: GetManyRequest, - writer: &mut ProgressWriter, -) -> Result<()> { + writer: &mut ProgressWriter, +) -> Result<(), HandleGetManyError> { debug!("get_many received request"); let request_ranges = request.ranges.iter_infinite(); for (child, (hash, ranges)) in request.hashes.iter().zip(request_ranges).enumerate() { @@ -455,14 +529,53 @@ pub async fn handle_get_many( Ok(()) } +pub async fn handle_get_many( + mut pair: StreamPair, + store: Store, + request: GetManyRequest, +) -> anyhow::Result<()> { + let res = pair.get_many_request(|| request.clone()).await; + let tracker = handle_read_request_result::(&mut pair, res).await?; + let mut writer = pair.into_writer(tracker).await?; + let res = handle_get_many_impl(store, request, &mut writer).await; + handle_write_result::(&mut writer, res).await?; + Ok(()) +} + +#[derive(Debug, Snafu)] +pub enum HandlePushError { + #[snafu(transparent)] + ExportBao { + source: ExportBaoError, + }, + + InvalidHashSeq, + + #[snafu(transparent)] + Request { + source: RequestError, + }, +} + +impl HasErrorCode for HandlePushError { + fn code(&self) -> VarInt { + match self { + Self::ExportBao { + source: ExportBaoError::Progress { source, .. }, + } => source.code(), + _ => ERR_INTERNAL, + } + } +} + /// Handle a single push request. /// /// Requires a database, the request, and a reader. -pub async fn handle_push( +async fn handle_push_impl( store: Store, request: PushRequest, - reader: &mut ProgressReader, -) -> Result<()> { + reader: &mut ProgressReader, +) -> Result<(), HandlePushError> { let hash = request.hash; debug!(%hash, "push received request"); let mut request_ranges = request.ranges.iter_infinite(); @@ -479,7 +592,7 @@ pub async fn handle_push( } // todo: we assume here that the hash sequence is complete. 
For some requests this might not be the case. We would need `LazyHashSeq` for that, but it is buggy as of now! let hash_seq = store.get_bytes(hash).await?; - let hash_seq = HashSeq::try_from(hash_seq)?; + let hash_seq = HashSeq::try_from(hash_seq).map_err(|_| HandlePushError::InvalidHashSeq)?; for (child_hash, child_ranges) in hash_seq.into_iter().zip(request_ranges) { if child_ranges.is_empty() { continue; @@ -491,13 +604,26 @@ pub async fn handle_push( Ok(()) } +pub async fn handle_push( + mut pair: StreamPair, + store: Store, + request: PushRequest, +) -> anyhow::Result<()> { + let res = pair.push_request(|| request.clone()).await; + let tracker = handle_read_request_result::(&mut pair, res).await?; + let mut reader = pair.into_reader(tracker).await?; + let res = handle_push_impl(store, request, &mut reader).await; + handle_read_result::(&mut reader, res).await?; + Ok(()) +} + /// Send a blob to the client. -pub(crate) async fn send_blob( +pub(crate) async fn send_blob( store: &Store, index: u64, hash: Hash, ranges: ChunkRanges, - writer: &mut ProgressWriter, + writer: &mut ProgressWriter, ) -> ExportBaoResult<()> { store .export_bao(hash, ranges) @@ -505,26 +631,46 @@ pub(crate) async fn send_blob( .await } +#[derive(Debug, Snafu)] +pub enum HandleObserveError { + ObserveStreamClosed, + + #[snafu(transparent)] + RemoteClosed { + source: io::Error, + }, +} + +impl HasErrorCode for HandleObserveError { + fn code(&self) -> VarInt { + ERR_INTERNAL + } +} + /// Handle a single push request. /// /// Requires a database, the request, and a reader. -pub async fn handle_observe( +async fn handle_observe_impl( store: Store, request: ObserveRequest, writer: &mut ProgressWriter, -) -> Result<()> { - let mut stream = store.observe(request.hash).stream().await?; +) -> std::result::Result<(), HandleObserveError> { + let mut stream = store + .observe(request.hash) + .stream() + .await + .map_err(|_| HandleObserveError::ObserveStreamClosed)?; let mut old = stream .next() .await - .ok_or(anyhow::anyhow!("observe stream closed before first value"))?; + .ok_or(HandleObserveError::ObserveStreamClosed)?; // send the initial bitfield send_observe_item(writer, &old).await?; // send updates until the remote loses interest loop { select! 
{ new = stream.next() => { - let new = new.context("observe stream closed")?; + let new = new.ok_or(HandleObserveError::ObserveStreamClosed)?; let diff = old.diff(&new); if diff.is_empty() { continue; @@ -541,7 +687,7 @@ pub async fn handle_observe( Ok(()) } -async fn send_observe_item(writer: &mut ProgressWriter, item: &Bitfield) -> Result<()> { +async fn send_observe_item(writer: &mut ProgressWriter, item: &Bitfield) -> io::Result<()> { use irpc::util::AsyncWriteVarintExt; let item = ObserveItem::from(item); let len = writer.inner.0.write_length_prefixed(item).await?; @@ -549,12 +695,25 @@ async fn send_observe_item(writer: &mut ProgressWriter, item: &Bitfield) -> Resu Ok(()) } +pub async fn handle_observe>( + mut pair: StreamPair, + store: Store, + request: ObserveRequest, +) -> anyhow::Result<()> { + let res = pair.observe_request(|| request.clone()).await; + let tracker = handle_read_request_result::(&mut pair, res).await?; + let mut writer = pair.into_writer(tracker).await?; + let res = handle_observe_impl(store, request, &mut writer).await; + handle_write_result::(&mut writer, res).await?; + Ok(()) +} + pub struct ProgressReader { inner: R, context: ReaderContext, } -impl ProgressReader { +impl ProgressReader { async fn transfer_aborted(&self) { self.context .tracker @@ -572,24 +731,106 @@ impl ProgressReader { } } -pub(crate) trait RecvStreamExt { - async fn read_to_end_as( - &mut self, - max_size: usize, - ) -> io::Result<(T, usize)>; -} +pub(crate) trait RecvStreamExt: AsyncStreamReader { + async fn expect_eof(&mut self) -> io::Result<()> { + match self.read_u8().await { + Ok(_) => Err(io::Error::new( + io::ErrorKind::InvalidData, + "unexpected data", + )), + Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => Ok(()), + Err(e) => Err(e), + } + } + + async fn read_u8(&mut self) -> io::Result { + let buf = self.read::<1>().await?; + Ok(buf[0]) + } -impl RecvStreamExt for RecvStream { async fn read_to_end_as( &mut self, max_size: usize, ) -> io::Result<(T, usize)> { - let data = self - .read_to_end(max_size) - .await - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + let data = self.read_bytes(max_size).await?; + self.expect_eof().await?; let value = postcard::from_bytes(&data) .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; Ok((value, data.len())) } + + async fn read_length_prefixed( + &mut self, + max_size: usize, + ) -> io::Result { + let Some(n) = self.read_varint_u64().await? else { + return Err(io::ErrorKind::UnexpectedEof.into()); + }; + if n > max_size as u64 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "length prefix too large", + )); + } + let n = n as usize; + let data = self.read_bytes(n).await?; + let value = postcard::from_bytes(&data) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + Ok(value) + } + + /// Reads a u64 varint from an AsyncRead source, using the Postcard/LEB128 format. + /// + /// In Postcard's varint format (LEB128): + /// - Each byte uses 7 bits for the value + /// - The MSB (most significant bit) of each byte indicates if there are more bytes (1) or not (0) + /// - Values are stored in little-endian order (least significant group first) + /// + /// Returns the decoded u64 value. 
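+    ///
+    /// Worked example: the bytes `[0x96, 0x01]` decode to 150. The first byte
+    /// contributes its low 7 bits (0x16 = 22) and sets the continuation bit,
+    /// the second byte contributes 1 << 7 = 128, and 22 + 128 = 150.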
+    async fn read_varint_u64(&mut self) -> io::Result<Option<u64>> {
+        let mut result: u64 = 0;
+        let mut shift: u32 = 0;
+
+        loop {
+            // We can only shift up to 63 bits (for a u64)
+            if shift >= 64 {
+                return Err(io::Error::new(
+                    io::ErrorKind::InvalidData,
+                    "Varint is too large for u64",
+                ));
+            }
+
+            // Read a single byte
+            let res = self.read_u8().await;
+            if shift == 0 {
+                if let Err(cause) = res {
+                    if cause.kind() == io::ErrorKind::UnexpectedEof {
+                        return Ok(None);
+                    } else {
+                        return Err(cause);
+                    }
+                }
+            }
+
+            let byte = res?;
+
+            // Extract the 7 value bits (bits 0-6, excluding the MSB which is the continuation bit)
+            let value = (byte & 0x7F) as u64;
+
+            // Add the bits to our result at the current shift position
+            result |= value << shift;
+
+            // If the high bit is not set (0), this is the last byte
+            if byte & 0x80 == 0 {
+                break;
+            }
+
+            // Move to the next 7 bits
+            shift += 7;
+        }
+
+        Ok(Some(result))
+    }
 }
+
+impl<R: AsyncStreamReader> RecvStreamExt for R {}
diff --git a/src/provider/events.rs b/src/provider/events.rs
index fff800dc9..06cce4c29 100644
--- a/src/provider/events.rs
+++ b/src/provider/events.rs
@@ -95,15 +95,21 @@ impl From<ProgressError> for io::Error {
     }
 }
 
-impl ProgressError {
-    pub fn code(&self) -> quinn::VarInt {
+pub trait HasErrorCode {
+    fn code(&self) -> quinn::VarInt;
+}
+
+impl HasErrorCode for ProgressError {
+    fn code(&self) -> quinn::VarInt {
         match self {
             ProgressError::Limit => ERR_LIMIT,
             ProgressError::Permission => ERR_PERMISSION,
             ProgressError::Internal { .. } => ERR_INTERNAL,
         }
     }
+}
 
+impl ProgressError {
     pub fn reason(&self) -> &'static [u8] {
         match self {
             ProgressError::Limit => b"limit",
diff --git a/src/tests.rs b/src/tests.rs
index 0b3b2fde2..9e5ea89f2 100644
--- a/src/tests.rs
+++ b/src/tests.rs
@@ -579,7 +579,7 @@ async fn two_nodes_hash_seq_progress() -> TestResult<()> {
     let root = add_test_hash_seq(&store1, sizes).await?;
     let conn = r2.endpoint().connect(addr1, crate::ALPN).await?;
     let mut stream = store2.remote().fetch(conn, root).stream();
-    while let Some(_) = stream.next().await {}
+    while stream.next().await.is_some() {}
     check_presence(&store2, &sizes).await?;
     Ok(())
 }
@@ -647,7 +647,7 @@ async fn node_serve_blobs() -> TestResult<()> {
         let expected = test_data(size);
         let hash = Hash::new(&expected);
         let mut stream = get::request::get_blob(conn.clone(), hash);
-        while let Some(_) = stream.next().await {}
+        while stream.next().await.is_some() {}
         let actual = get::request::get_blob(conn.clone(), hash).await?;
         assert_eq!(actual.len(), expected.len(), "size: {size}");
     }
From 4c4a5e77145aeb6bc469f353f8b0941749640928 Mon Sep 17 00:00:00 2001
From: Ruediger Klaehn
Date: Mon, 8 Sep 2025 12:23:17 +0300
Subject: [PATCH 26/35] Add example how to add compression to the entire blobs
 protocol.
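
The example speaks the blobs protocol over LZ4 compressed streams under a
custom ALPN. The iroh send and recv streams are wrapped in the LZ4 encoder and
decoder from `async-compression`, adapted with the `TokioStreamWriter` and
`TokioStreamReader` wrappers from `iroh-io`. The provide side reuses the
generic `StreamPair` and `handle_get` machinery from the previous commit with
a custom `ErrorHandler`; the get side drives the `AtConnected` state machine
directly over the wrapped streams.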
--- Cargo.lock | 97 ++++++++++++++++++++ Cargo.toml | 4 +- examples/compression.rs | 191 ++++++++++++++++++++++++++++++++++++++++ src/get.rs | 22 ++++- src/provider.rs | 19 ++-- src/provider/events.rs | 5 +- 6 files changed, 328 insertions(+), 10 deletions(-) create mode 100644 examples/compression.rs diff --git a/Cargo.lock b/Cargo.lock index 988d7955a..625f30b7b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -156,6 +156,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "async-compression" +version = "0.4.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "977eb15ea9efd848bb8a4a1a2500347ed7f0bf794edf0dc3ddcf439f43d36b23" +dependencies = [ + "compression-codecs", + "compression-core", + "futures-core", + "pin-project-lite", + "tokio", +] + [[package]] name = "async-trait" version = "0.1.88" @@ -374,6 +387,8 @@ version = "1.2.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "deec109607ca693028562ed836a5f1c4b8bd77755c4e132fc5ce11b0b6211ae7" dependencies = [ + "jobserver", + "libc", "shlex", ] @@ -508,6 +523,24 @@ dependencies = [ "memchr", ] +[[package]] +name = "compression-codecs" +version = "0.4.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "485abf41ac0c8047c07c87c72c8fb3eb5197f6e9d7ded615dfd1a00ae00a0f64" +dependencies = [ + "compression-core", + "lz4", + "zstd", + "zstd-safe", +] + +[[package]] +name = "compression-core" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e47641d3deaf41fb1538ac1f54735925e275eaf3bf4d55c81b137fba797e5cbb" + [[package]] name = "const-oid" version = "0.9.6" @@ -1741,6 +1774,7 @@ version = "0.93.0" dependencies = [ "anyhow", "arrayvec", + "async-compression", "atomic_refcell", "bao-tree", "bytes", @@ -2008,6 +2042,16 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" +[[package]] +name = "jobserver" +version = "0.1.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38f262f097c174adebe41eb73d66ae9c06b2844fb0da69969647bbddd9b0538a" +dependencies = [ + "getrandom 0.3.3", + "libc", +] + [[package]] name = "js-sys" version = "0.3.77" @@ -2092,6 +2136,25 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" +[[package]] +name = "lz4" +version = "1.28.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a20b523e860d03443e98350ceaac5e71c6ba89aea7d960769ec3ce37f4de5af4" +dependencies = [ + "lz4-sys", +] + +[[package]] +name = "lz4-sys" +version = "1.11.1+lz4-1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bd8c0d6c6ed0cd30b3652886bb8711dc4bb01d637a68105a3d5158039b418e6" +dependencies = [ + "cc", + "libc", +] + [[package]] name = "matchers" version = "0.2.0" @@ -2653,6 +2716,12 @@ dependencies = [ "spki", ] +[[package]] +name = "pkg-config" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" + [[package]] name = "pnet_base" version = "0.34.0" @@ -5197,3 +5266,31 @@ dependencies = [ "quote", "syn 2.0.104", ] + +[[package]] +name = "zstd" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "7.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.15+zstd.1.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb81183ddd97d0c74cedf1d50d85c8d08c1b8b68ee863bdee9e706eedba1a237" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/Cargo.toml b/Cargo.toml index bcd5f42d0..a40b735bd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,6 +44,7 @@ iroh-base = "0.91.1" reflink-copy = "0.1.24" irpc = { version = "0.7.0", features = ["rpc", "quinn_endpoint_setup", "spans", "stream", "derive"], default-features = false } iroh-metrics = { version = "0.35" } +async-compression = { version = "0.4.30", features = ["lz4", "tokio"] } [dev-dependencies] clap = { version = "4.5.31", features = ["derive"] } @@ -60,6 +61,7 @@ tracing-test = "0.2.5" walkdir = "2.5.0" atomic_refcell = "0.1.13" iroh = { version = "0.91.1", features = ["discovery-local-network"]} +async-compression = { version = "0.4.30", features = ["zstd", "tokio"] } [features] hide-proto-docs = [] @@ -68,4 +70,4 @@ default = ["hide-proto-docs"] [patch.crates-io] iroh = { git = "https://github.com/n0-computer/iroh", branch = "main" } -iroh-base = { git = "https://github.com/n0-computer/iroh", branch = "main" } \ No newline at end of file +iroh-base = { git = "https://github.com/n0-computer/iroh", branch = "main" } diff --git a/examples/compression.rs b/examples/compression.rs new file mode 100644 index 000000000..32697254d --- /dev/null +++ b/examples/compression.rs @@ -0,0 +1,191 @@ +/// Example how to limit blob requests by hash and node id, and to add +/// throttling or limiting the maximum number of connections. +/// +/// Limiting is done via a fn that returns an EventSender and internally +/// makes liberal use of spawn to spawn background tasks. +/// +/// This is fine, since the tasks will terminate as soon as the [BlobsProtocol] +/// instance holding the [EventSender] will be dropped. But for production +/// grade code you might nevertheless put the tasks into a [tokio::task::JoinSet] or +/// [n0_future::FuturesUnordered]. +mod common; +use std::{path::PathBuf, time::Instant}; + +use anyhow::Result; +use async_compression::tokio::{bufread::Lz4Decoder, write::Lz4Encoder}; +use bao_tree::blake3; +use clap::Parser; +use common::setup_logging; +use iroh::protocol::ProtocolHandler; +use iroh_blobs::{ + api::Store, + get::fsm::{AtConnected, ConnectedNext, EndBlobNext}, + protocol::{ChunkRangesSeq, GetRequest, Request}, + provider::{ + events::{ClientConnected, EventSender, HasErrorCode}, + handle_get, ErrorHandler, StreamPair, + }, + store::mem::MemStore, + ticket::BlobTicket, +}; +use iroh_io::{TokioStreamReader, TokioStreamWriter}; +use tokio::io::BufReader; +use tracing::debug; + +use crate::common::get_or_generate_secret_key; + +#[derive(Debug, Parser)] +#[command(version, about)] +pub enum Args { + /// Limit requests by node id + Provide { + /// Path for files to add. + path: PathBuf, + }, + /// Get a blob. Just for completeness sake. 
+ Get { + /// Ticket for the blob to download + ticket: BlobTicket, + /// Path to save the blob to + #[clap(long)] + target: Option, + }, +} + +type CompressedWriter = + TokioStreamWriter>; +type CompressedReader = TokioStreamReader< + async_compression::tokio::bufread::Lz4Decoder>, +>; + +#[derive(Debug, Clone)] +struct CompressedBlobsProtocol { + store: Store, + events: EventSender, +} + +impl CompressedBlobsProtocol { + fn new(store: &Store, events: EventSender) -> Self { + Self { + store: store.clone(), + events, + } + } +} + +struct CompressedErrorHandler; + +impl ErrorHandler for CompressedErrorHandler { + type W = CompressedWriter; + + type R = CompressedReader; + + async fn stop(reader: &mut Self::R, code: quinn::VarInt) { + reader.0.get_mut().get_mut().stop(code).ok(); + } + + async fn reset(writer: &mut Self::W, code: quinn::VarInt) { + writer.0.get_mut().reset(code).ok(); + } +} + +impl ProtocolHandler for CompressedBlobsProtocol { + async fn accept( + &self, + connection: iroh::endpoint::Connection, + ) -> std::result::Result<(), iroh::protocol::AcceptError> { + let connection_id = connection.stable_id() as u64; + let node_id = connection.remote_node_id()?; + if let Err(cause) = self + .events + .client_connected(|| ClientConnected { + connection_id, + node_id, + }) + .await + { + connection.close(cause.code(), cause.reason()); + debug!("closing connection: {cause}"); + return Ok(()); + } + while let Ok((send, recv)) = connection.accept_bi().await { + let stream_id = send.id().index(); + let send = TokioStreamWriter(Lz4Encoder::new(send)); + let recv = TokioStreamReader(Lz4Decoder::new(BufReader::new(recv))); + let store = self.store.clone(); + let mut pair = + StreamPair::new(connection_id, stream_id, recv, send, self.events.clone()); + tokio::spawn(async move { + let request = pair.read_request().await?; + if let Request::Get(request) = request { + handle_get::(pair, store, request).await?; + } + anyhow::Ok(()) + }); + } + Ok(()) + } +} + +const ALPN: &[u8] = b"iroh-blobs-compressed/0.1.0"; + +#[tokio::main] +async fn main() -> Result<()> { + setup_logging(); + let args = Args::parse(); + let secret = get_or_generate_secret_key()?; + let endpoint = iroh::Endpoint::builder() + .secret_key(secret) + .discovery_n0() + .bind() + .await?; + match args { + Args::Provide { path } => { + let store = MemStore::new(); + let tag = store.add_path(path).await?; + let blobs = CompressedBlobsProtocol::new(&store, EventSender::DEFAULT); + let router = iroh::protocol::Router::builder(endpoint.clone()) + .accept(ALPN, blobs) + .spawn(); + let ticket = BlobTicket::new(endpoint.node_id().into(), tag.hash, tag.format); + println!("Serving blob with hash {}", tag.hash); + println!("Ticket: {ticket}"); + println!("Node is running. Press Ctrl-C to exit."); + tokio::signal::ctrl_c().await?; + println!("Shutting down."); + router.shutdown().await?; + } + Args::Get { ticket, target } => { + let conn = endpoint.connect(ticket.node_addr().clone(), ALPN).await?; + let (send, recv) = conn.open_bi().await?; + let send = TokioStreamWriter(Lz4Encoder::new(send)); + let recv = TokioStreamReader(Lz4Decoder::new(BufReader::new(recv))); + let request = GetRequest { + hash: ticket.hash(), + ranges: ChunkRangesSeq::root(), + }; + let connected = + AtConnected::new(Instant::now(), recv, send, request, Default::default()); + let ConnectedNext::StartRoot(start) = connected.next().await? 
else { + unreachable!("expected start root"); + }; + let (end, data) = start.next().concatenate_into_vec().await?; + let EndBlobNext::Closing(closing) = end.next() else { + unreachable!("expected closing"); + }; + let stats = closing.next().await?; + if let Some(target) = target { + tokio::fs::write(&target, &data).await?; + println!( + "Wrote {} bytes to {}", + stats.payload_bytes_read, + target.display() + ); + } else { + let hash = blake3::hash(&data); + println!("Hash: {hash}"); + } + } + } + Ok(()) +} diff --git a/src/get.rs b/src/get.rs index bbe8ad26b..3accc55f4 100644 --- a/src/get.rs +++ b/src/get.rs @@ -341,7 +341,23 @@ pub mod fsm { Write { source: io::Error }, } - impl AtConnected { + impl AtConnected { + pub fn new( + start: Instant, + reader: R, + writer: W, + request: GetRequest, + counters: RequestCounters, + ) -> Self { + Self { + start, + reader, + writer, + request, + counters, + } + } + /// Send the request and move to the next state /// /// The next state will be either `StartRoot` or `StartChild` depending on whether @@ -377,6 +393,10 @@ pub mod fsm { .write_bytes(request_bytes.into()) .await .context(connected_next_error::WriteSnafu)?; + writer + .sync() + .await + .context(connected_next_error::WriteSnafu)?; len }; diff --git a/src/provider.rs b/src/provider.rs index e81d783ad..a79e0ad8f 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -323,7 +323,7 @@ pub async fn handle_connection( while let Ok(context) = StreamPair::accept(&connection, progress.clone()).await { let span = debug_span!("stream", stream_id = %context.request_id); let store = store.clone(); - tokio::spawn(handle_stream(store, context).instrument(span)); + tokio::spawn(handle_stream(context, store).instrument(span)); } progress .connection_closed(|| ConnectionClosed { connection_id }) @@ -400,17 +400,17 @@ impl ErrorHandler for IrohErrorHandler { } } -pub async fn handle_stream(store: Store, mut context: StreamPair) -> anyhow::Result<()> { +pub async fn handle_stream(mut pair: StreamPair, store: Store) -> anyhow::Result<()> { // 1. Decode the request. 
debug!("reading request"); - let request = context.read_request().await?; + let request = pair.read_request().await?; type H = IrohErrorHandler; match request { - Request::Get(request) => handle_get::(context, store, request).await?, - Request::GetMany(request) => handle_get_many::(context, store, request).await?, - Request::Observe(request) => handle_observe::(context, store, request).await?, - Request::Push(request) => handle_push::(context, store, request).await?, + Request::Get(request) => handle_get::(pair, store, request).await?, + Request::GetMany(request) => handle_get_many::(pair, store, request).await?, + Request::Observe(request) => handle_observe::(pair, store, request).await?, + Request::Push(request) => handle_push::(pair, store, request).await?, _ => {} } Ok(()) @@ -477,6 +477,11 @@ async fn handle_get_impl( send_blob(&store, offset, hash, ranges.clone(), writer).await?; } } + writer + .inner + .sync() + .await + .map_err(|e| HandleGetError::ExportBao { source: e.into() })?; Ok(()) } diff --git a/src/provider/events.rs b/src/provider/events.rs index 06cce4c29..54511f92c 100644 --- a/src/provider/events.rs +++ b/src/provider/events.rs @@ -339,7 +339,10 @@ impl EventSender { } while let Some(msg) = rx.recv().await { match msg { - ProviderMessage::ClientConnected(_) => todo!(), + ProviderMessage::ClientConnected(msg) => { + trace!("{:?}", msg.inner); + msg.tx.send(Ok(())).await.ok(); + } ProviderMessage::ClientConnectedNotify(msg) => { trace!("{:?}", msg.inner); } From f3d02e7a94368a606a9e1308067c87df40db6233 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Tue, 9 Sep 2025 11:36:38 +0300 Subject: [PATCH 27/35] Working adapters --- examples/{compression.rs => compression.rs_} | 12 +- src/get.rs | 67 +++--- src/provider.rs | 218 ++++++++++++++++++- 3 files changed, 249 insertions(+), 48 deletions(-) rename examples/{compression.rs => compression.rs_} (93%) diff --git a/examples/compression.rs b/examples/compression.rs_ similarity index 93% rename from examples/compression.rs rename to examples/compression.rs_ index 32697254d..0e1e1aead 100644 --- a/examples/compression.rs +++ b/examples/compression.rs_ @@ -22,13 +22,11 @@ use iroh_blobs::{ get::fsm::{AtConnected, ConnectedNext, EndBlobNext}, protocol::{ChunkRangesSeq, GetRequest, Request}, provider::{ - events::{ClientConnected, EventSender, HasErrorCode}, - handle_get, ErrorHandler, StreamPair, + events::{ClientConnected, EventSender, HasErrorCode}, handle_get, AsyncReadRecvStream, AsyncWriteSendStream, ErrorHandler, StreamPair }, store::mem::MemStore, ticket::BlobTicket, }; -use iroh_io::{TokioStreamReader, TokioStreamWriter}; use tokio::io::BufReader; use tracing::debug; @@ -53,8 +51,8 @@ pub enum Args { } type CompressedWriter = - TokioStreamWriter>; -type CompressedReader = TokioStreamReader< + AsyncWriteSendStream>; +type CompressedReader = AsyncReadRecvStream< async_compression::tokio::bufread::Lz4Decoder>, >; @@ -158,8 +156,8 @@ async fn main() -> Result<()> { Args::Get { ticket, target } => { let conn = endpoint.connect(ticket.node_addr().clone(), ALPN).await?; let (send, recv) = conn.open_bi().await?; - let send = TokioStreamWriter(Lz4Encoder::new(send)); - let recv = TokioStreamReader(Lz4Decoder::new(BufReader::new(recv))); + let send = AsyncWriteSendStream(Lz4Encoder::new(send)); + let recv = AsyncReadRecvStream::new(Lz4Decoder::new(BufReader::new(recv))); let request = GetRequest { hash: ticket.hash(), ranges: ChunkRangesSeq::root(), diff --git a/src/get.rs b/src/get.rs index 3accc55f4..1961e7394 100644 --- 
a/src/get.rs +++ b/src/get.rs @@ -82,8 +82,8 @@ impl AsyncStreamReader for IrohStreamReader { } } -type DefaultReader = IrohStreamReader; -type DefaultWriter = IrohStreamWriter; +type DefaultReader = iroh::endpoint::RecvStream; +type DefaultWriter = iroh::endpoint::SendStream; /// Stats about the transfer. #[derive( @@ -139,14 +139,14 @@ pub mod fsm { }; use derive_more::From; use iroh::endpoint::Connection; - use iroh_io::{AsyncSliceWriter, AsyncStreamReader, AsyncStreamWriter}; + use iroh_io::{AsyncSliceWriter}; use super::*; use crate::{ get::get_error::BadRequestSnafu, protocol::{ GetManyRequest, GetRequest, NonEmptyRequestRangeSpecIter, Request, MAX_MESSAGE_SIZE, - }, + }, provider::{RecvStream, RecvStreamAsyncStreamReader, SendStream}, }; self_cell::self_cell! { @@ -173,17 +173,15 @@ pub mod fsm { counters: RequestCounters, ) -> std::result::Result, GetError> { let start = Instant::now(); - let (writer, reader) = connection + let (mut writer, reader) = connection .open_bi() .await .map_err(|e| OpenSnafu.into_error(e.into()))?; - let reader = IrohStreamReader(reader); - let mut writer = IrohStreamWriter(writer); let request = Request::GetMany(request); let request_bytes = postcard::to_stdvec(&request) .map_err(|source| BadRequestSnafu.into_error(source.into()))?; writer - .write_bytes(request_bytes.into()) + .send_bytes(request_bytes.into()) .await .context(connected_next_error::WriteSnafu)?; let Request::GetMany(request) = request else { @@ -270,8 +268,6 @@ pub mod fsm { .open_bi() .await .map_err(|e| OpenSnafu.into_error(e.into()))?; - let reader = IrohStreamReader(reader); - let writer = IrohStreamWriter(writer); Ok(AtConnected { start, reader, @@ -298,8 +294,8 @@ pub mod fsm { /// State of the get response machine after the handshake has been sent #[derive(Debug)] pub struct AtConnected< - R: AsyncStreamReader = DefaultReader, - W: AsyncStreamWriter = DefaultWriter, + R: RecvStream = DefaultReader, + W: SendStream = DefaultWriter, > { start: Instant, reader: R, @@ -310,7 +306,7 @@ pub mod fsm { /// Possible next states after the handshake has been sent #[derive(Debug, From)] - pub enum ConnectedNext { + pub enum ConnectedNext { /// First response is either a collection or a single blob StartRoot(AtStartRoot), /// First response is a child @@ -341,7 +337,7 @@ pub mod fsm { Write { source: io::Error }, } - impl AtConnected { + impl AtConnected { pub fn new( start: Instant, reader: R, @@ -390,7 +386,7 @@ pub mod fsm { // write the request itself let len = request_bytes.len() as u64; writer - .write_bytes(request_bytes.into()) + .send_bytes(request_bytes.into()) .await .context(connected_next_error::WriteSnafu)?; writer @@ -438,7 +434,7 @@ pub mod fsm { /// State of the get response when we start reading a collection #[derive(Debug)] - pub struct AtStartRoot { + pub struct AtStartRoot { ranges: ChunkRanges, reader: R, misc: Box, @@ -447,14 +443,14 @@ pub mod fsm { /// State of the get response when we start reading a child #[derive(Debug)] - pub struct AtStartChild { + pub struct AtStartChild { ranges: ChunkRanges, reader: R, misc: Box, offset: u64, } - impl AtStartChild { + impl AtStartChild { /// The offset of the child we are currently reading /// /// This must be used to determine the hash needed to call next. 
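    ///
    /// Offset 0 is the root blob of the request; for offset `n > 0` the hash
    /// is entry `n - 1` of the hash sequence.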
@@ -491,7 +487,7 @@ pub mod fsm { } } - impl AtStartRoot { + impl AtStartRoot { /// The ranges we have requested for the child pub fn ranges(&self) -> &ChunkRanges { &self.ranges @@ -522,7 +518,8 @@ pub mod fsm { /// State before reading a size header #[derive(Debug)] - pub struct AtBlobHeader { + pub struct AtBlobHeader + { ranges: ChunkRanges, reader: R, misc: Box, @@ -560,10 +557,10 @@ pub mod fsm { } } - impl AtBlobHeader { + impl AtBlobHeader { /// Read the size header, returning it and going into the `Content` state. pub async fn next(mut self) -> Result<(AtBlobContent, u64), AtBlobHeaderNextError> { - let size = self.reader.read::<8>().await.map_err(|cause| { + let size = self.reader.recv::<8>().await.map_err(|cause| { if cause.kind() == io::ErrorKind::UnexpectedEof { at_blob_header_next_error::NotFoundSnafu.build() } else { @@ -576,7 +573,7 @@ pub mod fsm { self.hash.into(), self.ranges, BaoTree::new(size, IROH_BLOCK_SIZE), - self.reader, + RecvStreamAsyncStreamReader::new(self.reader), ); Ok(( AtBlobContent { @@ -650,8 +647,8 @@ pub mod fsm { /// State while we are reading content #[derive(Debug)] - pub struct AtBlobContent { - stream: ResponseDecoder, + pub struct AtBlobContent { + stream: ResponseDecoder>, misc: Box, } @@ -765,7 +762,7 @@ pub mod fsm { /// The next state after reading a content item #[derive(Debug, From)] - pub enum BlobContentNext { + pub enum BlobContentNext { /// We expect more content More( ( @@ -777,7 +774,7 @@ pub mod fsm { Done(AtEndBlob), } - impl AtBlobContent { + impl AtBlobContent { /// Read the next item, either content, an error, or the end of the blob pub async fn next(self) -> BlobContentNext { match self.stream.next().await { @@ -796,7 +793,7 @@ pub mod fsm { BlobContentNext::More((next, res)) } ResponseDecoderNext::Done(stream) => BlobContentNext::Done(AtEndBlob { - stream, + stream: stream.into_inner(), misc: self.misc, }), } @@ -933,27 +930,27 @@ pub mod fsm { /// Immediately finish the get response without reading further pub fn finish(self) -> AtClosing { - AtClosing::new(self.misc, self.stream.finish(), false) + AtClosing::new(self.misc, self.stream.finish().into_inner(), false) } } /// State after we have read all the content for a blob #[derive(Debug)] - pub struct AtEndBlob { + pub struct AtEndBlob { stream: R, misc: Box, } /// The next state after the end of a blob #[derive(Debug, From)] - pub enum EndBlobNext { + pub enum EndBlobNext { /// Response is expected to have more children MoreChildren(AtStartChild), /// No more children expected Closing(AtClosing), } - impl AtEndBlob { + impl AtEndBlob { /// Read the next child, or finish pub fn next(mut self) -> EndBlobNext { if let Some((offset, ranges)) = self.misc.ranges_iter.next() { @@ -972,13 +969,13 @@ pub mod fsm { /// State when finishing the get response #[derive(Debug)] - pub struct AtClosing { + pub struct AtClosing { misc: Box, reader: R, check_extra_data: bool, } - impl AtClosing { + impl AtClosing { fn new(misc: Box, reader: R, check_extra_data: bool) -> Self { Self { misc, @@ -992,7 +989,7 @@ pub mod fsm { // Shut down the stream let mut reader = self.reader; if self.check_extra_data { - let rest = reader.read_bytes(1).await?; + let rest = reader.recv_bytes(1).await?; if !rest.is_empty() { error!("Unexpected extra data at the end of the stream"); } diff --git a/src/provider.rs b/src/provider.rs index a79e0ad8f..1c6d41356 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -4,21 +4,19 @@ //! to provide data is to just register a [`crate::BlobsProtocol`] protocol //! 
handler with an [`iroh::Endpoint`](iroh::protocol::Router).
 use std::{
-    fmt::Debug,
-    future::Future,
-    io,
-    time::{Duration, Instant},
+    fmt::Debug, future::Future, io, ops::DerefMut, time::{Duration, Instant}
 };
 
 use anyhow::Result;
 use bao_tree::ChunkRanges;
+use bytes::Bytes;
 use iroh::endpoint;
 use iroh_io::{AsyncStreamReader, AsyncStreamWriter};
 use n0_future::StreamExt;
-use quinn::{ConnectionError, VarInt};
+use quinn::{ConnectionError, ReadExactError, VarInt};
 use serde::{de::DeserializeOwned, Deserialize, Serialize};
 use snafu::Snafu;
-use tokio::select;
+use tokio::{io::AsyncRead, select};
 use tracing::{debug, debug_span, warn, Instrument};
 
 use crate::{
@@ -334,6 +332,214 @@ pub async fn handle_connection(
         .await
 }
 
+/// An abstract `iroh::endpoint::SendStream`.
+pub trait SendStream: Send {
+    /// Send bytes to the stream. This takes a `Bytes` because iroh can directly use them.
+    fn send_bytes(&mut self, bytes: Bytes) -> impl Future<Output = io::Result<()>> + Send;
+    /// Send a fixed-size buffer.
+    fn send<const L: usize>(&mut self, buf: &[u8; L]) -> impl Future<Output = io::Result<()>> + Send;
+    /// Sync the stream. Not needed for iroh, but needed for intermediate buffered streams such as compression.
+    fn sync(&mut self) -> impl Future<Output = io::Result<()>> + Send;
+    /// Reset the stream with the given error code.
+    fn reset(&mut self, code: VarInt) -> io::Result<()>;
+    /// Wait for the stream to be stopped, returning the error code if it was.
+    fn stopped(&mut self) -> impl Future<Output = io::Result<Option<VarInt>>> + Send;
+}
+
+/// An abstract `iroh::endpoint::RecvStream`.
+pub trait RecvStream: Send {
+    /// Receive up to `len` bytes from the stream, directly into a `Bytes`.
+    fn recv_bytes(&mut self, len: usize) -> impl Future<Output = io::Result<Bytes>> + Send;
+    /// Receive exactly `len` bytes from the stream, directly into a `Bytes`.
+    ///
+    /// This will return an error if the stream ends before `len` bytes are read.
+    ///
+    /// Note that this is different from `recv_bytes`, which will return fewer bytes if the stream ends.
+    fn recv_bytes_exact(
+        &mut self,
+        len: usize,
+    ) -> impl Future<Output = io::Result<Bytes>> + Send;
+    /// Receive exactly `L` bytes from the stream, directly into a `[u8; L]`.
+    fn recv<const L: usize>(&mut self) -> impl Future<Output = io::Result<[u8; L]>> + Send;
+    /// Stop the stream with the given error code.
+    fn stop(&mut self, code: VarInt) -> io::Result<()>;
+}
+
+impl SendStream for iroh::endpoint::SendStream {
+    async fn send_bytes(&mut self, bytes: Bytes) -> io::Result<()> {
+        Ok(self.write_chunk(bytes).await?)
+    }
+
+    async fn send<const L: usize>(&mut self, buf: &[u8; L]) -> io::Result<()> {
+        Ok(self.write_all(buf).await?)
+    }
+
+    async fn sync(&mut self) -> io::Result<()> {
+        Ok(())
+    }
+
+    fn reset(&mut self, code: VarInt) -> io::Result<()> {
+        Ok(self.reset(code)?)
+    }
+
+    async fn stopped(&mut self) -> io::Result<Option<VarInt>> {
+        Ok(self.stopped().await?)
+    }
+}
+
+impl RecvStream for iroh::endpoint::RecvStream {
+    async fn recv_bytes(&mut self, len: usize) -> io::Result<Bytes> {
+        let mut buf = vec![0; len];
+        match self.read_exact(&mut buf).await {
+            Err(ReadExactError::FinishedEarly(n)) => {
+                buf.truncate(n);
+            }
+            Err(ReadExactError::ReadError(e)) => {
+                return Err(e.into());
+            }
+            Ok(()) => {}
+        };
+        Ok(buf.into())
+    }
+
+    async fn recv_bytes_exact(
+        &mut self,
+        len: usize,
+    ) -> io::Result<Bytes> {
+        let mut buf = vec![0; len];
+        self.read_exact(&mut buf).await.map_err(|e| {
+            match e {
+                ReadExactError::FinishedEarly(0) => io::Error::new(io::ErrorKind::UnexpectedEof, ""),
+                ReadExactError::FinishedEarly(_) => io::Error::new(io::ErrorKind::InvalidData, ""),
+                ReadExactError::ReadError(e) => e.into(),
+            }
+        })?;
+        Ok(buf.into())
+    }
+
+    async fn recv<const L: usize>(&mut self) -> io::Result<[u8; L]> {
+        let mut buf = [0; L];
+        self.read_exact(&mut buf).await.map_err(|e| {
+            match e {
+                ReadExactError::FinishedEarly(0) => io::Error::new(io::ErrorKind::UnexpectedEof, ""),
+                ReadExactError::FinishedEarly(_) => io::Error::new(io::ErrorKind::InvalidData, ""),
+                ReadExactError::ReadError(e) => e.into(),
+            }
+        })?;
+        Ok(buf)
+    }
+
+    fn stop(&mut self, code: VarInt) -> io::Result<()> {
+        Ok(self.stop(code)?)
+    }
+}
+
+#[derive(Debug)]
+pub struct AsyncReadRecvStream<R>(R);
+
+impl<R> AsyncReadRecvStream<R> {
+    pub fn new(inner: R) -> Self {
+        Self(inner)
+    }
+
+    pub fn into_inner(self) -> R {
+        self.0
+    }
+}
+
+use tokio::io::AsyncReadExt;
+
+impl<R: AsyncRead + DerefMut<Target = iroh::endpoint::RecvStream> + Send + Unpin> RecvStream
+    for AsyncReadRecvStream<R>
+{
+    async fn recv_bytes(&mut self, len: usize) -> io::Result<Bytes> {
+        let mut res = vec![0; len];
+        let mut n = 0;
+        loop {
+            let read = self.0.read(&mut res[n..]).await?;
+            if read == 0 {
+                res.truncate(n);
+                break;
+            }
+            n += read;
+            if n == len {
+                break;
+            }
+        }
+        Ok(res.into())
+    }
+
+    async fn recv_bytes_exact(
+        &mut self,
+        len: usize,
+    ) -> io::Result<Bytes> {
+        let mut res = vec![0; len];
+        self.0.read_exact(&mut res).await?;
+        Ok(res.into())
+    }
+
+    async fn recv<const L: usize>(&mut self) -> io::Result<[u8; L]> {
+        let mut res = [0; L];
+        self.0.read_exact(&mut res).await?;
+        Ok(res)
+    }
+
+    fn stop(&mut self, code: VarInt) -> io::Result<()> {
+        self.0.deref_mut().stop(code)?;
+        Ok(())
+    }
+}
+
+/// Utility to convert a [tokio::io::AsyncWrite] into a [SendStream].
+#[derive(Debug, Clone)]
+pub struct AsyncWriteSendStream<W>(pub W);
+
+use tokio::io::AsyncWriteExt;
+
+impl<W: tokio::io::AsyncWrite + DerefMut<Target = iroh::endpoint::SendStream> + Send + Unpin> SendStream
+    for AsyncWriteSendStream<W>
+{
+    async fn send_bytes(&mut self, bytes: Bytes) -> io::Result<()> {
+        self.0.write_all(&bytes).await
+    }
+
+    async fn send<const L: usize>(&mut self, buf: &[u8; L]) -> io::Result<()> {
+        self.0.write_all(buf).await
+    }
+
+    async fn sync(&mut self) -> io::Result<()> {
+        self.0.flush().await
+    }
+
+    fn reset(&mut self, code: VarInt) -> io::Result<()> {
+        self.0.deref_mut().reset(code)?;
+        Ok(())
+    }
+
+    async fn stopped(&mut self) -> io::Result<Option<VarInt>> {
+        Ok(self.0.deref_mut().stopped().await?)
+    }
+}
+
+#[derive(Debug)]
+pub struct RecvStreamAsyncStreamReader<R>(R);
+
+impl<R> RecvStreamAsyncStreamReader<R> {
+    pub fn new(inner: R) -> Self {
+        Self(inner)
+    }
+
+    pub fn into_inner(self) -> R {
+        self.0
+    }
+}
+
+impl<R: RecvStream> AsyncStreamReader for RecvStreamAsyncStreamReader<R> {
+    async fn read_bytes(&mut self, len: usize) -> io::Result<Bytes> {
+        self.0.recv_bytes_exact(len).await
+    }
+
+    async fn read<const L: usize>(&mut self) -> io::Result<[u8; L]> {
+        self.0.recv::<L>().await
+    }
+}
+
 /// Describes how to handle errors for a stream.
pub trait ErrorHandler { type W: AsyncStreamWriter; From 41284d2a5da358da8629bfc0f1c051353de8742c Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Tue, 9 Sep 2025 12:49:31 +0300 Subject: [PATCH 28/35] compression example works again --- examples/{compression.rs_ => compression.rs} | 119 +++++-- src/api/blobs.rs | 27 +- src/api/remote.rs | 7 +- src/get.rs | 58 +--- src/protocol.rs | 5 +- src/provider.rs | 322 ++++++++++++------- 6 files changed, 332 insertions(+), 206 deletions(-) rename examples/{compression.rs_ => compression.rs} (63%) diff --git a/examples/compression.rs_ b/examples/compression.rs similarity index 63% rename from examples/compression.rs_ rename to examples/compression.rs index 0e1e1aead..5d0a05f7a 100644 --- a/examples/compression.rs_ +++ b/examples/compression.rs @@ -9,7 +9,7 @@ /// grade code you might nevertheless put the tasks into a [tokio::task::JoinSet] or /// [n0_future::FuturesUnordered]. mod common; -use std::{path::PathBuf, time::Instant}; +use std::{io, path::PathBuf, time::Instant}; use anyhow::Result; use async_compression::tokio::{bufread::Lz4Decoder, write::Lz4Encoder}; @@ -22,7 +22,8 @@ use iroh_blobs::{ get::fsm::{AtConnected, ConnectedNext, EndBlobNext}, protocol::{ChunkRangesSeq, GetRequest, Request}, provider::{ - events::{ClientConnected, EventSender, HasErrorCode}, handle_get, AsyncReadRecvStream, AsyncWriteSendStream, ErrorHandler, StreamPair + events::{ClientConnected, EventSender, HasErrorCode}, + handle_get, AsyncReadRecvStream, AsyncWriteSendStream, StreamPair, }, store::mem::MemStore, ticket::BlobTicket, @@ -50,11 +51,91 @@ pub enum Args { }, } -type CompressedWriter = - AsyncWriteSendStream>; -type CompressedReader = AsyncReadRecvStream< +struct CompressedWriter(async_compression::tokio::write::Lz4Encoder); +struct CompressedReader( async_compression::tokio::bufread::Lz4Decoder>, ->; +); + +impl iroh_blobs::provider::SendStream for CompressedWriter { + async fn send_bytes(&mut self, bytes: bytes::Bytes) -> io::Result<()> { + AsyncWriteSendStream::new(self).send_bytes(bytes).await + } + + async fn send(&mut self, buf: &[u8; L]) -> io::Result<()> { + AsyncWriteSendStream::new(self).send(buf).await + } + + async fn sync(&mut self) -> io::Result<()> { + AsyncWriteSendStream::new(self).sync().await + } +} + +impl iroh_blobs::provider::RecvStream for CompressedReader { + async fn recv_bytes(&mut self, len: usize) -> io::Result { + AsyncReadRecvStream::new(self).recv_bytes(len).await + } + + async fn recv_bytes_exact(&mut self, len: usize) -> io::Result { + AsyncReadRecvStream::new(self).recv_bytes_exact(len).await + } + + async fn recv(&mut self) -> io::Result<[u8; L]> { + AsyncReadRecvStream::new(self).recv::().await + } +} + +impl tokio::io::AsyncRead for CompressedReader { + fn poll_read( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + std::pin::Pin::new(&mut self.0).poll_read(cx, buf) + } +} + +impl tokio::io::AsyncWrite for CompressedWriter { + fn poll_write( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + std::pin::Pin::new(&mut self.0).poll_write(cx, buf) + } + + fn poll_flush( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::pin::Pin::new(&mut self.0).poll_flush(cx) + } + + fn poll_shutdown( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + std::pin::Pin::new(&mut 
self.0).poll_shutdown(cx) + } +} + +impl iroh_blobs::provider::SendStreamSpecific for CompressedWriter { + fn reset(&mut self, code: quinn::VarInt) -> io::Result<()> { + self.0.get_mut().reset(code)?; + Ok(()) + } + + async fn stopped(&mut self) -> io::Result> { + let res = self.0.get_mut().stopped().await?; + Ok(res) + } +} + +impl iroh_blobs::provider::RecvStreamSpecific for CompressedReader { + fn stop(&mut self, code: quinn::VarInt) -> io::Result<()> { + self.0.get_mut().get_mut().stop(code)?; + Ok(()) + } +} #[derive(Debug, Clone)] struct CompressedBlobsProtocol { @@ -71,22 +152,6 @@ impl CompressedBlobsProtocol { } } -struct CompressedErrorHandler; - -impl ErrorHandler for CompressedErrorHandler { - type W = CompressedWriter; - - type R = CompressedReader; - - async fn stop(reader: &mut Self::R, code: quinn::VarInt) { - reader.0.get_mut().get_mut().stop(code).ok(); - } - - async fn reset(writer: &mut Self::W, code: quinn::VarInt) { - writer.0.get_mut().reset(code).ok(); - } -} - impl ProtocolHandler for CompressedBlobsProtocol { async fn accept( &self, @@ -108,15 +173,15 @@ impl ProtocolHandler for CompressedBlobsProtocol { } while let Ok((send, recv)) = connection.accept_bi().await { let stream_id = send.id().index(); - let send = TokioStreamWriter(Lz4Encoder::new(send)); - let recv = TokioStreamReader(Lz4Decoder::new(BufReader::new(recv))); + let send = CompressedWriter(Lz4Encoder::new(send)); + let recv = CompressedReader(Lz4Decoder::new(BufReader::new(recv))); let store = self.store.clone(); let mut pair = StreamPair::new(connection_id, stream_id, recv, send, self.events.clone()); tokio::spawn(async move { let request = pair.read_request().await?; if let Request::Get(request) = request { - handle_get::(pair, store, request).await?; + handle_get(pair, store, request).await?; } anyhow::Ok(()) }); @@ -156,8 +221,8 @@ async fn main() -> Result<()> { Args::Get { ticket, target } => { let conn = endpoint.connect(ticket.node_addr().clone(), ALPN).await?; let (send, recv) = conn.open_bi().await?; - let send = AsyncWriteSendStream(Lz4Encoder::new(send)); - let recv = AsyncReadRecvStream::new(Lz4Decoder::new(BufReader::new(recv))); + let send = CompressedWriter(Lz4Encoder::new(send)); + let recv = CompressedReader(Lz4Decoder::new(BufReader::new(recv))); let request = GetRequest { hash: ticket.hash(), ranges: ChunkRangesSeq::root(), diff --git a/src/api/blobs.rs b/src/api/blobs.rs index 830dd2042..9d81f04bb 100644 --- a/src/api/blobs.rs +++ b/src/api/blobs.rs @@ -23,7 +23,7 @@ use bao_tree::{ }; use bytes::Bytes; use genawaiter::sync::Gen; -use iroh_io::{AsyncStreamReader, AsyncStreamWriter}; +use iroh_io::AsyncStreamWriter; use irpc::channel::{mpsc, oneshot}; use n0_future::{future, stream, Stream, StreamExt}; use range_collections::{range_set::RangeSetRange, RangeSet2}; @@ -55,7 +55,7 @@ use super::{ }; use crate::{ api::proto::{BatchRequest, ImportByteStreamUpdate}, - provider::events::ClientResult, + provider::{events::ClientResult, RecvStreamAsyncStreamReader}, store::IROH_BLOCK_SIZE, util::temp_tag::TempTag, BlobFormat, Hash, HashAndFormat, @@ -429,13 +429,13 @@ impl Blobs { } #[cfg_attr(feature = "hide-proto-docs", doc(hidden))] - pub async fn import_bao_reader( + pub async fn import_bao_reader( &self, hash: Hash, ranges: ChunkRanges, mut reader: R, ) -> RequestResult { - let size = u64::from_le_bytes(reader.read::<8>().await.map_err(super::Error::other)?); + let size = u64::from_le_bytes(reader.recv::<8>().await.map_err(super::Error::other)?); let Some(size) = NonZeroU64::new(size) 
else { return if hash == Hash::EMPTY { Ok(reader) @@ -444,7 +444,12 @@ impl Blobs { }; }; let tree = BaoTree::new(size.get(), IROH_BLOCK_SIZE); - let mut decoder = ResponseDecoder::new(hash.into(), ranges, tree, reader); + let mut decoder = ResponseDecoder::new( + hash.into(), + ranges, + tree, + RecvStreamAsyncStreamReader::new(reader), + ); let options = ImportBaoOptions { hash, size }; let handle = self.import_bao_with_opts(options, 32).await?; let driver = async move { @@ -463,7 +468,7 @@ impl Blobs { let fut = async move { handle.rx.await.map_err(io::Error::other)? }; let (reader, res) = tokio::join!(driver, fut); res?; - Ok(reader?) + Ok(reader?.into_inner()) } #[cfg_attr(feature = "hide-proto-docs", doc(hidden))] @@ -1068,7 +1073,7 @@ impl ExportBaoProgress { } /// Write quinn variant that also feeds a progress writer. - pub(crate) async fn write_with_progress( + pub(crate) async fn write_with_progress( self, writer: &mut W, progress: &mut impl WriteProgress, @@ -1080,19 +1085,19 @@ impl ExportBaoProgress { match item { EncodedItem::Size(size) => { progress.send_transfer_started(index, hash, size).await; - writer.write(&size.to_le_bytes()).await?; + writer.send(&size.to_le_bytes()).await?; progress.log_other_write(8); } EncodedItem::Parent(parent) => { - let mut data = vec![0u8; 64]; + let mut data = [0u8; 64]; data[..32].copy_from_slice(parent.pair.0.as_bytes()); data[32..].copy_from_slice(parent.pair.1.as_bytes()); - writer.write(&data).await?; + writer.send(&data).await?; progress.log_other_write(64); } EncodedItem::Leaf(leaf) => { let len = leaf.data.len(); - writer.write_bytes(leaf.data).await?; + writer.send_bytes(leaf.data).await?; progress .notify_payload_write(index, leaf.offset, len) .await?; diff --git a/src/api/remote.rs b/src/api/remote.rs index 8877c6297..ab2a52e86 100644 --- a/src/api/remote.rs +++ b/src/api/remote.rs @@ -16,7 +16,7 @@ use crate::{ get::{ fsm::DecodeError, get_error::{BadRequestSnafu, LocalFailureSnafu}, - GetError, GetResult, IrohStreamWriter, Stats, + GetError, GetResult, Stats, }, protocol::{ GetManyRequest, ObserveItem, ObserveRequest, PushRequest, Request, RequestType, @@ -593,7 +593,6 @@ impl Remote { let mut request_ranges = request.ranges.iter_infinite(); let root = request.hash; let root_ranges = request_ranges.next().expect("infinite iterator"); - let mut send = IrohStreamWriter(send); if !root_ranges.is_empty() { self.store() .export_bao(root, root_ranges.clone()) @@ -602,7 +601,7 @@ impl Remote { } if request.ranges.is_blob() { // we are done - send.0.finish()?; + send.finish()?; return Ok(Default::default()); } let hash_seq = self.store().get_bytes(root).await?; @@ -617,7 +616,7 @@ impl Remote { .await?; } } - send.0.finish()?; + send.finish()?; Ok(Default::default()) } diff --git a/src/get.rs b/src/get.rs index 1961e7394..bcea4b1b7 100644 --- a/src/get.rs +++ b/src/get.rs @@ -18,20 +18,16 @@ //! 
[iroh]: https://docs.rs/iroh use std::{ fmt::{self, Debug}, - io, time::{Duration, Instant}, }; use anyhow::Result; use bao_tree::{io::fsm::BaoContentItem, ChunkNum}; use fsm::RequestCounters; -use iroh_io::{AsyncStreamReader, AsyncStreamWriter}; use n0_snafu::SpanTrace; use nested_enum_utils::common_fields; -use quinn::ReadExactError; use serde::{Deserialize, Serialize}; use snafu::{Backtrace, IntoError, ResultExt, Snafu}; -use tokio::io::AsyncWriteExt; use tracing::{debug, error}; use crate::{protocol::ChunkRangesSeq, store::IROH_BLOCK_SIZE, Hash}; @@ -41,47 +37,6 @@ pub mod request; pub(crate) use error::get_error; pub use error::{GetError, GetResult}; -pub struct IrohStreamWriter(pub iroh::endpoint::SendStream); - -impl AsyncStreamWriter for IrohStreamWriter { - async fn write(&mut self, data: &[u8]) -> io::Result<()> { - Ok(self.0.write_all(data).await?) - } - - async fn write_bytes(&mut self, data: bytes::Bytes) -> io::Result<()> { - Ok(self.0.write_chunk(data).await?) - } - - async fn sync(&mut self) -> io::Result<()> { - self.0.flush().await - } -} - -pub struct IrohStreamReader(pub iroh::endpoint::RecvStream); - -impl AsyncStreamReader for IrohStreamReader { - async fn read(&mut self) -> io::Result<[u8; N]> { - let mut buf = [0u8; N]; - match self.0.read_exact(&mut buf).await { - Ok(()) => Ok(buf), - Err(ReadExactError::ReadError(e)) => Err(e.into()), - Err(ReadExactError::FinishedEarly(_)) => Err(io::ErrorKind::UnexpectedEof.into()), - } - } - - async fn read_bytes(&mut self, len: usize) -> io::Result { - let mut buf = vec![0u8; len]; - match self.0.read_exact(&mut buf).await { - Ok(()) => Ok(buf.into()), - Err(ReadExactError::ReadError(e)) => Err(e.into()), - Err(ReadExactError::FinishedEarly(n)) => { - buf.truncate(n); - Ok(buf.into()) - } - } - } -} - type DefaultReader = iroh::endpoint::RecvStream; type DefaultWriter = iroh::endpoint::SendStream; @@ -139,14 +94,15 @@ pub mod fsm { }; use derive_more::From; use iroh::endpoint::Connection; - use iroh_io::{AsyncSliceWriter}; + use iroh_io::AsyncSliceWriter; use super::*; use crate::{ get::get_error::BadRequestSnafu, protocol::{ GetManyRequest, GetRequest, NonEmptyRequestRangeSpecIter, Request, MAX_MESSAGE_SIZE, - }, provider::{RecvStream, RecvStreamAsyncStreamReader, SendStream}, + }, + provider::{RecvStream, RecvStreamAsyncStreamReader, SendStream}, }; self_cell::self_cell! 
{ @@ -293,10 +249,7 @@ pub mod fsm { /// State of the get response machine after the handshake has been sent #[derive(Debug)] - pub struct AtConnected< - R: RecvStream = DefaultReader, - W: SendStream = DefaultWriter, - > { + pub struct AtConnected { start: Instant, reader: R, writer: W, @@ -518,8 +471,7 @@ pub mod fsm { /// State before reading a size header #[derive(Debug)] - pub struct AtBlobHeader - { + pub struct AtBlobHeader { ranges: ChunkRanges, reader: R, misc: Box, diff --git a/src/protocol.rs b/src/protocol.rs index 8aed6539a..8fcd28e9f 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -382,7 +382,6 @@ use bao_tree::{io::round_up_to_chunks, ChunkNum}; use builder::GetRequestBuilder; use derive_more::From; use iroh::endpoint::VarInt; -use iroh_io::AsyncStreamReader; use postcard::experimental::max_size::MaxSize; use range_collections::{range_set::RangeSetEntry, RangeSet2}; use serde::{Deserialize, Serialize}; @@ -447,7 +446,9 @@ pub enum RequestType { } impl Request { - pub async fn read_async(reader: &mut R) -> io::Result<(Self, usize)> { + pub async fn read_async( + reader: &mut R, + ) -> io::Result<(Self, usize)> { let request_type = reader.read_u8().await?; let request_type: RequestType = postcard::from_bytes(std::slice::from_ref(&request_type)) .map_err(|_| { diff --git a/src/provider.rs b/src/provider.rs index 1c6d41356..c42e68202 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -4,7 +4,11 @@ //! to provide data is to just register a [`crate::BlobsProtocol`] protocol //! handler with an [`iroh::Endpoint`](iroh::protocol::Router). use std::{ - fmt::Debug, future::Future, io, ops::DerefMut, time::{Duration, Instant} + fmt::Debug, + future::Future, + io, + ops::DerefMut, + time::{Duration, Instant}, }; use anyhow::Result; @@ -24,7 +28,6 @@ use crate::{ blobs::{Bitfield, WriteProgress}, ExportBaoError, ExportBaoResult, RequestError, Store, }, - get::{IrohStreamReader, IrohStreamWriter}, hashseq::HashSeq, protocol::{ GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, Request, ERR_INTERNAL, @@ -38,8 +41,8 @@ use crate::{ pub mod events; use events::EventSender; -type DefaultWriter = IrohStreamWriter; -type DefaultReader = IrohStreamReader; +type DefaultWriter = iroh::endpoint::SendStream; +type DefaultReader = iroh::endpoint::RecvStream; /// Statistics about a successful or failed transfer. #[derive(Debug, Serialize, Deserialize)] @@ -61,7 +64,10 @@ pub struct TransferStats { /// A pair of [`SendStream`] and [`RecvStream`] with additional context data. #[derive(Debug)] -pub struct StreamPair { +pub struct StreamPair< + R: crate::provider::RecvStream = DefaultReader, + W: crate::provider::SendStream = DefaultWriter, +> { t0: Instant, connection_id: u64, request_id: u64, @@ -80,14 +86,14 @@ impl StreamPair { Ok(Self::new( conn.stable_id() as u64, reader.id().into(), - IrohStreamReader(reader), - IrohStreamWriter(writer), + reader, + writer, events, )) } } -impl StreamPair { +impl StreamPair { pub fn new( connection_id: u64, request_id: u64, @@ -266,13 +272,13 @@ impl WriteProgress for WriterContext { /// Wrapper for a [`quinn::SendStream`] with additional per request information. 
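///
/// The attached `WriterContext` distinguishes payload bytes from other
/// protocol bytes and drives the per-request progress events, so provider
/// code writes through this wrapper rather than through the raw stream.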
#[derive(Debug)] -pub struct ProgressWriter { +pub struct ProgressWriter { /// The quinn::SendStream to write to pub inner: W, pub(crate) context: WriterContext, } -impl ProgressWriter { +impl ProgressWriter { fn new(inner: W, context: WriterContext) -> Self { Self { inner, context } } @@ -332,22 +338,33 @@ pub async fn handle_connection( .await } +pub trait SendStreamSpecific: Send { + /// Reset the stream with the given error code. + fn reset(&mut self, code: VarInt) -> io::Result<()>; + /// Wait for the stream to be stopped, returning the error code if it was. + fn stopped(&mut self) -> impl Future>> + Send; +} + /// An abstract `iroh::endpoint::SendStream`. -pub trait SendStream: Send { +pub trait SendStream: SendStreamSpecific { /// Send bytes to the stream. This takes a `Bytes` because iroh can directly use them. fn send_bytes(&mut self, bytes: Bytes) -> impl Future> + Send; /// Send that sends a fixed sized buffer. - fn send(&mut self, buf: &[u8; L]) -> impl Future> + Send; + fn send( + &mut self, + buf: &[u8; L], + ) -> impl Future> + Send; /// Sync the stream. Not needed for iroh, but needed for intermediate buffered streams such as compression. fn sync(&mut self) -> impl Future> + Send; - /// Reset the stream with the given error code. - fn reset(&mut self, code: VarInt) -> io::Result<()>; - /// Wait for the stream to be stopped, returning the error code if it was. - fn stopped(&mut self) -> impl Future>> + Send; +} + +pub trait RecvStreamSpecific: Send { + /// Stop the stream with the given error code. + fn stop(&mut self, code: VarInt) -> io::Result<()>; } /// An abstract `iroh::endpoint::RecvStream`. -pub trait RecvStream: Send { +pub trait RecvStream: RecvStreamSpecific { /// Receive up to `len` bytes from the stream, directly into a `Bytes`. fn recv_bytes(&mut self, len: usize) -> impl Future> + Send; /// Receive exactly `len` bytes from the stream, directly into a `Bytes`. @@ -355,14 +372,9 @@ pub trait RecvStream: Send { /// This will return an error if the stream ends before `len` bytes are read. /// /// Note that this is different from `recv_bytes`, which will return fewer bytes if the stream ends. - fn recv_bytes_exact( - &mut self, - len: usize, - ) -> impl Future> + Send; + fn recv_bytes_exact(&mut self, len: usize) -> impl Future> + Send; /// Receive exactly `L` bytes from the stream, directly into a `[u8; L]`. fn recv(&mut self) -> impl Future> + Send; - /// Stop the stream with the given error code. - fn stop(&mut self, code: VarInt) -> io::Result<()>; } impl SendStream for iroh::endpoint::SendStream { @@ -377,7 +389,9 @@ impl SendStream for iroh::endpoint::SendStream { async fn sync(&mut self) -> io::Result<()> { Ok(()) } +} +impl SendStreamSpecific for iroh::endpoint::SendStream { fn reset(&mut self, code: VarInt) -> io::Result<()> { Ok(self.reset(code)?) 
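        // Note: `self.reset(code)` resolves to the inherent
        // `iroh::endpoint::SendStream::reset`, not to this trait method,
        // since inherent methods take precedence during method resolution.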
} @@ -402,38 +416,77 @@ impl RecvStream for iroh::endpoint::RecvStream { Ok(buf.into()) } - async fn recv_bytes_exact( - &mut self, - len: usize, - ) -> io::Result { + async fn recv_bytes_exact(&mut self, len: usize) -> io::Result { let mut buf = vec![0; len]; - self.read_exact(&mut buf).await.map_err(|e| { - match e { - ReadExactError::FinishedEarly(0) => io::Error::new(io::ErrorKind::UnexpectedEof, ""), - ReadExactError::FinishedEarly(_) => io::Error::new(io::ErrorKind::InvalidData, ""), - ReadExactError::ReadError(e) => e.into(), - } + self.read_exact(&mut buf).await.map_err(|e| match e { + ReadExactError::FinishedEarly(0) => io::Error::new(io::ErrorKind::UnexpectedEof, ""), + ReadExactError::FinishedEarly(_) => io::Error::new(io::ErrorKind::InvalidData, ""), + ReadExactError::ReadError(e) => e.into(), })?; Ok(buf.into()) } async fn recv(&mut self) -> io::Result<[u8; L]> { let mut buf = [0; L]; - self.read_exact(&mut buf).await.map_err(|e| { - match e { - ReadExactError::FinishedEarly(0) => io::Error::new(io::ErrorKind::UnexpectedEof, ""), - ReadExactError::FinishedEarly(_) => io::Error::new(io::ErrorKind::InvalidData, ""), - ReadExactError::ReadError(e) => e.into(), - } + self.read_exact(&mut buf).await.map_err(|e| match e { + ReadExactError::FinishedEarly(0) => io::Error::new(io::ErrorKind::UnexpectedEof, ""), + ReadExactError::FinishedEarly(_) => io::Error::new(io::ErrorKind::InvalidData, ""), + ReadExactError::ReadError(e) => e.into(), })?; Ok(buf) } +} +impl RecvStreamSpecific for iroh::endpoint::RecvStream { fn stop(&mut self, code: VarInt) -> io::Result<()> { Ok(self.stop(code)?) } } +impl RecvStream for &mut R { + async fn recv_bytes(&mut self, len: usize) -> io::Result { + self.deref_mut().recv_bytes(len).await + } + + async fn recv_bytes_exact(&mut self, len: usize) -> io::Result { + self.deref_mut().recv_bytes_exact(len).await + } + + async fn recv(&mut self) -> io::Result<[u8; L]> { + self.deref_mut().recv::().await + } +} + +impl RecvStreamSpecific for &mut R { + fn stop(&mut self, code: VarInt) -> io::Result<()> { + self.deref_mut().stop(code) + } +} + +impl SendStream for &mut W { + async fn send_bytes(&mut self, bytes: Bytes) -> io::Result<()> { + self.deref_mut().send_bytes(bytes).await + } + + async fn send(&mut self, buf: &[u8; L]) -> io::Result<()> { + self.deref_mut().send(buf).await + } + + async fn sync(&mut self) -> io::Result<()> { + self.deref_mut().sync().await + } +} + +impl SendStreamSpecific for &mut W { + fn reset(&mut self, code: VarInt) -> io::Result<()> { + self.deref_mut().reset(code) + } + + async fn stopped(&mut self) -> io::Result> { + self.deref_mut().stopped().await + } +} + #[derive(Debug)] pub struct AsyncReadRecvStream(R); @@ -449,7 +502,7 @@ impl AsyncReadRecvStream { use tokio::io::AsyncReadExt; -impl> RecvStream for AsyncReadRecvStream { +impl RecvStream for AsyncReadRecvStream { async fn recv_bytes(&mut self, len: usize) -> io::Result { let mut res = vec![0; len]; let mut n = 0; @@ -467,10 +520,7 @@ impl Ok(res.into()) } - async fn recv_bytes_exact( - &mut self, - len: usize, - ) -> io::Result { + async fn recv_bytes_exact(&mut self, len: usize) -> io::Result { let mut res = vec![0; len]; self.0.read_exact(&mut res).await?; Ok(res.into()) @@ -481,20 +531,65 @@ impl self.0.read_exact(&mut res).await?; Ok(res) } +} +impl RecvStreamSpecific for AsyncReadRecvStream { fn stop(&mut self, code: VarInt) -> io::Result<()> { - self.0.deref_mut().stop(code)?; + self.0.stop(code) + } +} + +impl RecvStream for Bytes { + async fn recv_bytes(&mut self, 
len: usize) -> io::Result { + let n = len.min(self.len()); + let res = self.slice(..n); + *self = self.slice(n..); + Ok(res) + } + + async fn recv_bytes_exact(&mut self, len: usize) -> io::Result { + if self.len() < len { + return Err(io::ErrorKind::UnexpectedEof.into()); + } + let res = self.slice(..len); + *self = self.slice(len..); + Ok(res) + } + + async fn recv(&mut self) -> io::Result<[u8; L]> { + if self.len() < L { + return Err(io::ErrorKind::UnexpectedEof.into()); + } + let mut res = [0; L]; + res.copy_from_slice(&self[..L]); + *self = self.slice(L..); + Ok(res) + } +} + +impl RecvStreamSpecific for Bytes { + fn stop(&mut self, _code: VarInt) -> io::Result<()> { Ok(()) } } /// Utility to convert a [tokio::io::AsyncWrite] into an [SendStream]. #[derive(Debug, Clone)] -pub struct AsyncWriteSendStream(pub W); +pub struct AsyncWriteSendStream(W); + +impl AsyncWriteSendStream { + pub fn new(inner: W) -> Self { + Self(inner) + } + + pub fn into_inner(self) -> W { + self.0 + } +} use tokio::io::AsyncWriteExt; -impl> SendStream for AsyncWriteSendStream { +impl SendStream for AsyncWriteSendStream { async fn send_bytes(&mut self, bytes: Bytes) -> io::Result<()> { self.0.write_all(&bytes).await } @@ -506,14 +601,16 @@ impl io::Result<()> { self.0.flush().await } +} +impl SendStreamSpecific for AsyncWriteSendStream { fn reset(&mut self, code: VarInt) -> io::Result<()> { - self.0.deref_mut().reset(code)?; + self.0.reset(code)?; Ok(()) } async fn stopped(&mut self) -> io::Result> { - Ok(self.0.deref_mut().stopped().await?) + Ok(self.0.stopped().await?) } } @@ -548,20 +645,25 @@ pub trait ErrorHandler { fn reset(writer: &mut Self::W, code: VarInt) -> impl Future; } -async fn handle_read_request_result( - pair: &mut StreamPair, +async fn handle_read_request_result< + R: crate::provider::RecvStream, + W: crate::provider::SendStream, + T, + E: HasErrorCode, +>( + pair: &mut StreamPair, r: Result, ) -> Result { match r { Ok(x) => Ok(x), Err(e) => { - H::reset(&mut pair.writer, e.code()).await; + pair.writer.reset(e.code()).ok(); Err(e) } } } -async fn handle_write_result( - writer: &mut ProgressWriter, +async fn handle_write_result( + writer: &mut ProgressWriter, r: Result, ) -> Result { match r { @@ -570,14 +672,14 @@ async fn handle_write_result( Ok(x) } Err(e) => { - H::reset(&mut writer.inner, e.code()).await; + writer.inner.reset(e.code()).ok(); writer.transfer_aborted().await; Err(e) } } } -async fn handle_read_result( - reader: &mut ProgressReader, +async fn handle_read_result( + reader: &mut ProgressReader, r: Result, ) -> Result { match r { @@ -586,37 +688,23 @@ async fn handle_read_result( Ok(x) } Err(e) => { - H::stop(&mut reader.inner, e.code()).await; + reader.inner.stop(e.code()).ok(); reader.transfer_aborted().await; Err(e) } } } -struct IrohErrorHandler; - -impl ErrorHandler for IrohErrorHandler { - type W = DefaultWriter; - type R = DefaultReader; - - async fn stop(reader: &mut Self::R, code: VarInt) { - reader.0.stop(code).ok(); - } - async fn reset(writer: &mut Self::W, code: VarInt) { - writer.0.reset(code).ok(); - } -} pub async fn handle_stream(mut pair: StreamPair, store: Store) -> anyhow::Result<()> { // 1. Decode the request. 
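    // The first byte on the wire selects the request type; the rest of the
    // message is the postcard-encoded request itself (see `Request::read_async`).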
debug!("reading request"); let request = pair.read_request().await?; - type H = IrohErrorHandler; match request { - Request::Get(request) => handle_get::(pair, store, request).await?, - Request::GetMany(request) => handle_get_many::(pair, store, request).await?, - Request::Observe(request) => handle_observe::(pair, store, request).await?, - Request::Push(request) => handle_push::(pair, store, request).await?, + Request::Get(request) => handle_get(pair, store, request).await?, + Request::GetMany(request) => handle_get_many(pair, store, request).await?, + Request::Observe(request) => handle_observe(pair, store, request).await?, + Request::Push(request) => handle_push(pair, store, request).await?, _ => {} } Ok(()) @@ -649,7 +737,7 @@ impl HasErrorCode for HandleGetError { /// Handle a single get request. /// /// Requires a database, the request, and a writer. -async fn handle_get_impl( +async fn handle_get_impl( store: Store, request: GetRequest, writer: &mut ProgressWriter, @@ -692,16 +780,16 @@ async fn handle_get_impl( Ok(()) } -pub async fn handle_get( - mut pair: StreamPair, +pub async fn handle_get( + mut pair: StreamPair, store: Store, request: GetRequest, ) -> anyhow::Result<()> { let res = pair.get_request(|| request.clone()).await; - let tracker = handle_read_request_result::(&mut pair, res).await?; + let tracker = handle_read_request_result(&mut pair, res).await?; let mut writer = pair.into_writer(tracker).await?; let res = handle_get_impl(store, request, &mut writer).await; - handle_write_result::(&mut writer, res).await?; + handle_write_result(&mut writer, res).await?; Ok(()) } @@ -725,7 +813,7 @@ impl HasErrorCode for HandleGetManyError { /// Handle a single get request. /// /// Requires a database, the request, and a writer. -async fn handle_get_many_impl( +async fn handle_get_many_impl( store: Store, request: GetManyRequest, writer: &mut ProgressWriter, @@ -740,16 +828,16 @@ async fn handle_get_many_impl( Ok(()) } -pub async fn handle_get_many( - mut pair: StreamPair, +pub async fn handle_get_many( + mut pair: StreamPair, store: Store, request: GetManyRequest, ) -> anyhow::Result<()> { let res = pair.get_many_request(|| request.clone()).await; - let tracker = handle_read_request_result::(&mut pair, res).await?; + let tracker = handle_read_request_result(&mut pair, res).await?; let mut writer = pair.into_writer(tracker).await?; let res = handle_get_many_impl(store, request, &mut writer).await; - handle_write_result::(&mut writer, res).await?; + handle_write_result(&mut writer, res).await?; Ok(()) } @@ -782,7 +870,7 @@ impl HasErrorCode for HandlePushError { /// Handle a single push request. /// /// Requires a database, the request, and a reader. -async fn handle_push_impl( +async fn handle_push_impl( store: Store, request: PushRequest, reader: &mut ProgressReader, @@ -815,21 +903,21 @@ async fn handle_push_impl( Ok(()) } -pub async fn handle_push( - mut pair: StreamPair, +pub async fn handle_push( + mut pair: StreamPair, store: Store, request: PushRequest, ) -> anyhow::Result<()> { let res = pair.push_request(|| request.clone()).await; - let tracker = handle_read_request_result::(&mut pair, res).await?; + let tracker = handle_read_request_result(&mut pair, res).await?; let mut reader = pair.into_reader(tracker).await?; let res = handle_push_impl(store, request, &mut reader).await; - handle_read_result::(&mut reader, res).await?; + handle_read_result(&mut reader, res).await?; Ok(()) } /// Send a blob to the client. 
-pub(crate) async fn send_blob( +pub(crate) async fn send_blob( store: &Store, index: u64, hash: Hash, @@ -861,10 +949,10 @@ impl HasErrorCode for HandleObserveError { /// Handle a single push request. /// /// Requires a database, the request, and a reader. -async fn handle_observe_impl( +async fn handle_observe_impl( store: Store, request: ObserveRequest, - writer: &mut ProgressWriter, + writer: &mut ProgressWriter, ) -> std::result::Result<(), HandleObserveError> { let mut stream = store .observe(request.hash) @@ -889,7 +977,7 @@ async fn handle_observe_impl( send_observe_item(writer, &diff).await?; old = new; } - _ = writer.inner.0.stopped() => { + _ = writer.inner.stopped() => { debug!("observer closed"); break; } @@ -898,33 +986,35 @@ async fn handle_observe_impl( Ok(()) } -async fn send_observe_item(writer: &mut ProgressWriter, item: &Bitfield) -> io::Result<()> { - use irpc::util::AsyncWriteVarintExt; +async fn send_observe_item( + writer: &mut ProgressWriter, + item: &Bitfield, +) -> io::Result<()> { let item = ObserveItem::from(item); - let len = writer.inner.0.write_length_prefixed(item).await?; + let len = writer.inner.write_length_prefixed(item).await?; writer.context.log_other_write(len); Ok(()) } -pub async fn handle_observe>( - mut pair: StreamPair, +pub async fn handle_observe( + mut pair: StreamPair, store: Store, request: ObserveRequest, ) -> anyhow::Result<()> { let res = pair.observe_request(|| request.clone()).await; - let tracker = handle_read_request_result::(&mut pair, res).await?; + let tracker = handle_read_request_result(&mut pair, res).await?; let mut writer = pair.into_writer(tracker).await?; let res = handle_observe_impl(store, request, &mut writer).await; - handle_write_result::(&mut writer, res).await?; + handle_write_result(&mut writer, res).await?; Ok(()) } -pub struct ProgressReader { +pub struct ProgressReader { inner: R, context: ReaderContext, } -impl ProgressReader { +impl ProgressReader { async fn transfer_aborted(&self) { self.context .tracker @@ -942,7 +1032,7 @@ impl ProgressReader { } } -pub(crate) trait RecvStreamExt: AsyncStreamReader { +pub(crate) trait RecvStreamExt: crate::provider::RecvStream { async fn expect_eof(&mut self) -> io::Result<()> { match self.read_u8().await { Ok(_) => Err(io::Error::new( @@ -955,7 +1045,7 @@ pub(crate) trait RecvStreamExt: AsyncStreamReader { } async fn read_u8(&mut self) -> io::Result { - let buf = self.read::<1>().await?; + let buf = self.recv::<1>().await?; Ok(buf[0]) } @@ -963,7 +1053,7 @@ pub(crate) trait RecvStreamExt: AsyncStreamReader { &mut self, max_size: usize, ) -> io::Result<(T, usize)> { - let data = self.read_bytes(max_size).await?; + let data = self.recv_bytes(max_size).await?; self.expect_eof().await?; let value = postcard::from_bytes(&data) .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; @@ -984,7 +1074,7 @@ pub(crate) trait RecvStreamExt: AsyncStreamReader { )); } let n = n as usize; - let data = self.read_bytes(n).await?; + let data = self.recv_bytes(n).await?; let value = postcard::from_bytes(&data) .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; Ok(value) @@ -1044,4 +1134,18 @@ pub(crate) trait RecvStreamExt: AsyncStreamReader { } } -impl RecvStreamExt for R {} +impl RecvStreamExt for R {} + +pub(crate) trait SendStreamExt: crate::provider::SendStream { + async fn write_length_prefixed(&mut self, value: T) -> io::Result { + let size = postcard::experimental::serialized_size(&value) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + let mut buf = 
Vec::with_capacity(size + 9); + irpc::util::WriteVarintExt::write_length_prefixed(&mut buf, value)?; + let n = buf.len(); + self.send_bytes(buf.into()).await?; + Ok(n) + } +} + +impl SendStreamExt for W {} From 349c36b9a2bc86d3abb59062a4689c5b0bb01e50 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Tue, 9 Sep 2025 19:05:16 +0300 Subject: [PATCH 29/35] Generic receive into store --- examples/compression.rs | 151 ++++++++++---------------------------- src/api/downloader.rs | 2 +- src/api/remote.rs | 107 +++++++++++++++++---------- src/provider.rs | 159 +++++++++++++++++++--------------------- 4 files changed, 186 insertions(+), 233 deletions(-) diff --git a/examples/compression.rs b/examples/compression.rs index 5d0a05f7a..5cc9cbf30 100644 --- a/examples/compression.rs +++ b/examples/compression.rs @@ -9,26 +9,24 @@ /// grade code you might nevertheless put the tasks into a [tokio::task::JoinSet] or /// [n0_future::FuturesUnordered]. mod common; -use std::{io, path::PathBuf, time::Instant}; +use std::{io, path::PathBuf}; use anyhow::Result; use async_compression::tokio::{bufread::Lz4Decoder, write::Lz4Encoder}; -use bao_tree::blake3; use clap::Parser; use common::setup_logging; -use iroh::protocol::ProtocolHandler; +use iroh::{endpoint::VarInt, protocol::ProtocolHandler}; use iroh_blobs::{ api::Store, - get::fsm::{AtConnected, ConnectedNext, EndBlobNext}, - protocol::{ChunkRangesSeq, GetRequest, Request}, provider::{ events::{ClientConnected, EventSender, HasErrorCode}, - handle_get, AsyncReadRecvStream, AsyncWriteSendStream, StreamPair, + handle_stream, AsyncReadRecvStream, AsyncWriteSendStream, RecvStreamSpecific, + SendStreamSpecific, StreamPair, }, store::mem::MemStore, ticket::BlobTicket, }; -use tokio::io::BufReader; +use tokio::io::{AsyncRead, AsyncWrite, BufReader}; use tracing::debug; use crate::common::get_or_generate_secret_key; @@ -51,89 +49,35 @@ pub enum Args { }, } -struct CompressedWriter(async_compression::tokio::write::Lz4Encoder); -struct CompressedReader( - async_compression::tokio::bufread::Lz4Decoder>, -); +struct CompressedWriteStream(Lz4Encoder); -impl iroh_blobs::provider::SendStream for CompressedWriter { - async fn send_bytes(&mut self, bytes: bytes::Bytes) -> io::Result<()> { - AsyncWriteSendStream::new(self).send_bytes(bytes).await +impl SendStreamSpecific for CompressedWriteStream { + fn inner(&mut self) -> &mut (impl AsyncWrite + Unpin + Send) { + &mut self.0 } - async fn send(&mut self, buf: &[u8; L]) -> io::Result<()> { - AsyncWriteSendStream::new(self).send(buf).await + fn reset(&mut self, code: VarInt) -> io::Result<()> { + Ok(self.0.get_mut().reset(code)?) } - async fn sync(&mut self) -> io::Result<()> { - AsyncWriteSendStream::new(self).sync().await + async fn stopped(&mut self) -> io::Result> { + Ok(self.0.get_mut().stopped().await?) 
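        // `Lz4Encoder::get_mut` exposes the wrapped `iroh::endpoint::SendStream`,
        // so reset/stopped act on the raw QUIC stream and bypass the compressor.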
} } -impl iroh_blobs::provider::RecvStream for CompressedReader { - async fn recv_bytes(&mut self, len: usize) -> io::Result { - AsyncReadRecvStream::new(self).recv_bytes(len).await - } - - async fn recv_bytes_exact(&mut self, len: usize) -> io::Result { - AsyncReadRecvStream::new(self).recv_bytes_exact(len).await - } - - async fn recv(&mut self) -> io::Result<[u8; L]> { - AsyncReadRecvStream::new(self).recv::().await - } -} - -impl tokio::io::AsyncRead for CompressedReader { - fn poll_read( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> std::task::Poll> { - std::pin::Pin::new(&mut self.0).poll_read(cx, buf) - } -} - -impl tokio::io::AsyncWrite for CompressedWriter { - fn poll_write( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &[u8], - ) -> std::task::Poll> { - std::pin::Pin::new(&mut self.0).poll_write(cx, buf) - } - - fn poll_flush( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - std::pin::Pin::new(&mut self.0).poll_flush(cx) - } - - fn poll_shutdown( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - std::pin::Pin::new(&mut self.0).poll_shutdown(cx) - } -} +struct CompressedReadStream(Lz4Decoder>); -impl iroh_blobs::provider::SendStreamSpecific for CompressedWriter { - fn reset(&mut self, code: quinn::VarInt) -> io::Result<()> { - self.0.get_mut().reset(code)?; - Ok(()) +impl RecvStreamSpecific for CompressedReadStream { + fn inner(&mut self) -> &mut (impl AsyncRead + Unpin + Send) { + &mut self.0 } - async fn stopped(&mut self) -> io::Result> { - let res = self.0.get_mut().stopped().await?; - Ok(res) + fn stop(&mut self, code: VarInt) -> io::Result<()> { + Ok(self.0.get_mut().get_mut().stop(code)?) 
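        // Two `get_mut` calls: through the `Lz4Decoder` and then through the
        // `BufReader`, down to the underlying `iroh::endpoint::RecvStream`.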
} -} -impl iroh_blobs::provider::RecvStreamSpecific for CompressedReader { - fn stop(&mut self, code: quinn::VarInt) -> io::Result<()> { - self.0.get_mut().get_mut().stop(code)?; - Ok(()) + fn id(&self) -> u64 { + self.0.get_ref().get_ref().id().index() } } @@ -172,19 +116,13 @@ impl ProtocolHandler for CompressedBlobsProtocol { return Ok(()); } while let Ok((send, recv)) = connection.accept_bi().await { - let stream_id = send.id().index(); - let send = CompressedWriter(Lz4Encoder::new(send)); - let recv = CompressedReader(Lz4Decoder::new(BufReader::new(recv))); + let send = AsyncWriteSendStream::new(CompressedWriteStream(Lz4Encoder::new(send))); + let recv = AsyncReadRecvStream::new(CompressedReadStream(Lz4Decoder::new( + BufReader::new(recv), + ))); let store = self.store.clone(); - let mut pair = - StreamPair::new(connection_id, stream_id, recv, send, self.events.clone()); - tokio::spawn(async move { - let request = pair.read_request().await?; - if let Request::Get(request) = request { - handle_get(pair, store, request).await?; - } - anyhow::Ok(()) - }); + let pair = StreamPair::new(connection_id, recv, send, self.events.clone()); + tokio::spawn(handle_stream(pair, store)); } Ok(()) } @@ -219,34 +157,21 @@ async fn main() -> Result<()> { router.shutdown().await?; } Args::Get { ticket, target } => { + let store = MemStore::new(); let conn = endpoint.connect(ticket.node_addr().clone(), ALPN).await?; + let connection_id = conn.stable_id() as u64; let (send, recv) = conn.open_bi().await?; - let send = CompressedWriter(Lz4Encoder::new(send)); - let recv = CompressedReader(Lz4Decoder::new(BufReader::new(recv))); - let request = GetRequest { - hash: ticket.hash(), - ranges: ChunkRangesSeq::root(), - }; - let connected = - AtConnected::new(Instant::now(), recv, send, request, Default::default()); - let ConnectedNext::StartRoot(start) = connected.next().await? 
else { - unreachable!("expected start root"); - }; - let (end, data) = start.next().concatenate_into_vec().await?; - let EndBlobNext::Closing(closing) = end.next() else { - unreachable!("expected closing"); - }; - let stats = closing.next().await?; + let send = AsyncWriteSendStream::new(CompressedWriteStream(Lz4Encoder::new(send))); + let recv = AsyncReadRecvStream::new(CompressedReadStream(Lz4Decoder::new( + BufReader::new(recv), + ))); + let sp = StreamPair::new(connection_id, recv, send, EventSender::DEFAULT); + let stats = store.remote().fetch(sp, ticket.hash_and_format()).await?; if let Some(target) = target { - tokio::fs::write(&target, &data).await?; - println!( - "Wrote {} bytes to {}", - stats.payload_bytes_read, - target.display() - ); + let size = store.export(ticket.hash(), &target).await?; + println!("Wrote {} bytes to {}", size, target.display()); } else { - let hash = blake3::hash(&data); - println!("Hash: {hash}"); + println!("Hash: {}", ticket.hash()); } } } diff --git a/src/api/downloader.rs b/src/api/downloader.rs index 8ac188000..601dcca03 100644 --- a/src/api/downloader.rs +++ b/src/api/downloader.rs @@ -456,7 +456,7 @@ async fn execute_get( }; match remote .execute_get_sink( - &conn, + conn.clone(), local.missing(), (&mut progress).with_map(move |x| DownloadProgessItem::Progress(x + local_bytes)), ) diff --git a/src/api/remote.rs b/src/api/remote.rs index ab2a52e86..242d9a5b5 100644 --- a/src/api/remote.rs +++ b/src/api/remote.rs @@ -14,7 +14,7 @@ use super::blobs::{Bitfield, ExportBaoOptions}; use crate::{ api::{blobs::WriteProgress, ApiClient}, get::{ - fsm::DecodeError, + fsm::{AtConnected, DecodeError}, get_error::{BadRequestSnafu, LocalFailureSnafu}, GetError, GetResult, Stats, }, @@ -22,7 +22,10 @@ use crate::{ GetManyRequest, ObserveItem, ObserveRequest, PushRequest, Request, RequestType, MAX_MESSAGE_SIZE, }, - provider::events::{ClientResult, ProgressError}, + provider::{ + events::{ClientResult, EventSender, ProgressError}, + RecvStream, StreamPair, + }, util::sink::{Sink, TokioMpscSenderSink}, }; @@ -476,7 +479,7 @@ impl Remote { pub fn fetch( &self, - conn: impl GetConnection + Send + 'static, + sp: impl GetStreamPair + Send + 'static, content: impl Into, ) -> GetProgress { let content = content.into(); @@ -485,7 +488,7 @@ impl Remote { let sink = TokioMpscSenderSink(tx).with_map(GetProgressItem::Progress); let this = self.clone(); let fut = async move { - let res = this.fetch_sink(conn, content, sink).await.into(); + let res = this.fetch_sink(sp, content, sink).await.into(); tx2.send(res).await.ok(); }; GetProgress { @@ -503,7 +506,7 @@ impl Remote { /// This will return the stats of the download. 
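    ///
    /// Only the ranges still missing from the local store are requested; if
    /// the content is already complete locally, the connection is not used at
    /// all. The public `Remote::fetch` wrapper is typically driven like this
    /// (illustrative sketch; `store`, `conn` and `ticket` come from the caller):
    ///
    /// ```ignore
    /// let stats = store.remote().fetch(conn, ticket.hash_and_format()).await?;
    /// ```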
pub(crate) async fn fetch_sink( &self, - mut conn: impl GetConnection, + sp: impl GetStreamPair, content: impl Into, progress: impl Sink, ) -> GetResult { @@ -516,11 +519,7 @@ impl Remote { return Ok(Default::default()); } let request = local.missing(); - let conn = conn - .connection() - .await - .map_err(|e| LocalFailureSnafu.into_error(e))?; - let stats = self.execute_get_sink(&conn, request, progress).await?; + let stats = self.execute_get_sink(sp, request, progress).await?; Ok(stats) } @@ -620,17 +619,21 @@ impl Remote { Ok(Default::default()) } - pub fn execute_get(&self, conn: Connection, request: GetRequest) -> GetProgress { + pub fn execute_get(&self, conn: impl GetStreamPair, request: GetRequest) -> GetProgress { self.execute_get_with_opts(conn, request) } - pub fn execute_get_with_opts(&self, conn: Connection, request: GetRequest) -> GetProgress { + pub fn execute_get_with_opts( + &self, + conn: impl GetStreamPair, + request: GetRequest, + ) -> GetProgress { let (tx, rx) = tokio::sync::mpsc::channel(64); let tx2 = tx.clone(); let sink = TokioMpscSenderSink(tx).with_map(GetProgressItem::Progress); let this = self.clone(); let fut = async move { - let res = this.execute_get_sink(&conn, request, sink).await.into(); + let res = this.execute_get_sink(conn, request, sink).await.into(); tx2.send(res).await.ok(); }; GetProgress { @@ -649,16 +652,24 @@ impl Remote { /// This will return the stats of the download. pub(crate) async fn execute_get_sink( &self, - conn: &Connection, + conn: impl GetStreamPair, request: GetRequest, mut progress: impl Sink, ) -> GetResult { let store = self.store(); let root = request.hash; + let conn = conn.open_stream_pair().await.map_err(|e| { + LocalFailureSnafu.into_error(anyhow::anyhow!("failed to open stream pair: {e}")) + })?; // I am cloning the connection, but it's fine because the original connection or ConnectionRef stays alive // for the duration of the operation. - let start = crate::get::fsm::start(conn.clone(), request, Default::default()); - let connected = start.next().await?; + let connected = AtConnected::new( + conn.t0, + conn.reader, + conn.writer, + request, + Default::default(), + ); trace!("Getting header"); // read the header let next_child = match connected.next().await? 
{ @@ -840,29 +851,51 @@ use crate::{ Hash, HashAndFormat, }; -/// Trait to lazily get a connection -pub trait GetConnection { - fn connection(&mut self) - -> impl Future> + Send + '_; +pub trait GetStreamPair: Send + 'static { + fn open_stream_pair( + self, + ) -> impl Future< + Output = io::Result< + StreamPair, + >, + > + Send + + 'static; } -/// If we already have a connection, the impl is trivial -impl GetConnection for Connection { - fn connection( - &mut self, - ) -> impl Future> + Send + '_ { - let conn = self.clone(); - async { Ok(conn) } +impl + GetStreamPair for StreamPair +{ + fn open_stream_pair( + self, + ) -> impl Future< + Output = io::Result< + StreamPair, + >, + > + Send + + 'static { + async move { Ok(self) } } } -/// If we already have a connection, the impl is trivial -impl GetConnection for &Connection { - fn connection( - &mut self, - ) -> impl Future> + Send + '_ { - let conn = self.clone(); - async { Ok(conn) } +impl GetStreamPair for Connection { + fn open_stream_pair( + self, + ) -> impl Future< + Output = io::Result< + StreamPair, + >, + > + Send + + 'static { + let connection_id = self.stable_id() as u64; + async move { + let (send, recv) = self.open_bi().await?; + Ok(StreamPair::new( + connection_id, + recv, + send, + EventSender::DEFAULT, + )) + } } } @@ -870,12 +903,12 @@ fn get_buffer_size(size: NonZeroU64) -> usize { (size.get() / (IROH_BLOCK_SIZE.bytes() as u64) + 2).min(64) as usize } -async fn get_blob_ranges_impl( - header: AtBlobHeader, +async fn get_blob_ranges_impl( + header: AtBlobHeader, hash: Hash, store: &Store, mut progress: impl Sink, -) -> GetResult { +) -> GetResult> { let (mut content, size) = header.next().await?; let Some(size) = NonZeroU64::new(size) else { return if hash == Hash::EMPTY { diff --git a/src/provider.rs b/src/provider.rs index c42e68202..633cbd5f7 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -7,7 +7,7 @@ use std::{ fmt::Debug, future::Future, io, - ops::DerefMut, + ops::{Deref, DerefMut}, time::{Duration, Instant}, }; @@ -20,7 +20,10 @@ use n0_future::StreamExt; use quinn::{ConnectionError, ReadExactError, VarInt}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use snafu::Snafu; -use tokio::{io::AsyncRead, select}; +use tokio::{ + io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, + select, +}; use tracing::{debug, debug_span, warn, Instrument}; use crate::{ @@ -64,15 +67,11 @@ pub struct TransferStats { /// A pair of [`SendStream`] and [`RecvStream`] with additional context data. 
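///
/// For illustration (hypothetical `connection` and `events` values), the
/// provider builds such a pair from an accepted bi-directional stream, which
/// is exactly what `StreamPair::accept` does:
///
/// ```ignore
/// let (send, recv) = connection.accept_bi().await?;
/// let pair = StreamPair::new(connection.stable_id() as u64, recv, send, events.clone());
/// ```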
#[derive(Debug)] -pub struct StreamPair< - R: crate::provider::RecvStream = DefaultReader, - W: crate::provider::SendStream = DefaultWriter, -> { - t0: Instant, +pub struct StreamPair { + pub t0: Instant, connection_id: u64, - request_id: u64, - reader: R, - writer: W, + pub reader: R, + pub writer: W, other_bytes_read: u64, events: EventSender, } @@ -83,28 +82,19 @@ impl StreamPair { events: EventSender, ) -> Result { let (writer, reader) = conn.accept_bi().await?; - Ok(Self::new( - conn.stable_id() as u64, - reader.id().into(), - reader, - writer, - events, - )) + Ok(Self::new(conn.stable_id() as u64, reader, writer, events)) } } impl StreamPair { - pub fn new( - connection_id: u64, - request_id: u64, - reader: R, - writer: W, - events: EventSender, - ) -> Self { + pub fn stream_id(&self) -> u64 { + self.reader.id() + } + + pub fn new(connection_id: u64, reader: R, writer: W, events: EventSender) -> Self { Self { t0: Instant::now(), connection_id, - request_id, reader, writer, other_bytes_read: 0, @@ -166,7 +156,7 @@ impl StreamPair< f: impl FnOnce() -> GetRequest, ) -> Result { self.events - .request(f, self.connection_id, self.request_id) + .request(f, self.connection_id, self.reader.id()) .await } @@ -175,7 +165,7 @@ impl StreamPair< f: impl FnOnce() -> GetManyRequest, ) -> Result { self.events - .request(f, self.connection_id, self.request_id) + .request(f, self.connection_id, self.reader.id()) .await } @@ -184,7 +174,7 @@ impl StreamPair< f: impl FnOnce() -> PushRequest, ) -> Result { self.events - .request(f, self.connection_id, self.request_id) + .request(f, self.connection_id, self.reader.id()) .await } @@ -193,7 +183,7 @@ impl StreamPair< f: impl FnOnce() -> ObserveRequest, ) -> Result { self.events - .request(f, self.connection_id, self.request_id) + .request(f, self.connection_id, self.reader.id()) .await } @@ -324,10 +314,10 @@ pub async fn handle_connection( debug!("closing connection: {cause}"); return; } - while let Ok(context) = StreamPair::accept(&connection, progress.clone()).await { - let span = debug_span!("stream", stream_id = %context.request_id); + while let Ok(pair) = StreamPair::accept(&connection, progress.clone()).await { + let span = debug_span!("stream", stream_id = %pair.stream_id()); let store = store.clone(); - tokio::spawn(handle_stream(context, store).instrument(span)); + tokio::spawn(handle_stream(pair, store).instrument(span)); } progress .connection_closed(|| ConnectionClosed { connection_id }) @@ -338,15 +328,8 @@ pub async fn handle_connection( .await } -pub trait SendStreamSpecific: Send { - /// Reset the stream with the given error code. - fn reset(&mut self, code: VarInt) -> io::Result<()>; - /// Wait for the stream to be stopped, returning the error code if it was. - fn stopped(&mut self) -> impl Future>> + Send; -} - /// An abstract `iroh::endpoint::SendStream`. -pub trait SendStream: SendStreamSpecific { +pub trait SendStream: Send { /// Send bytes to the stream. This takes a `Bytes` because iroh can directly use them. fn send_bytes(&mut self, bytes: Bytes) -> impl Future> + Send; /// Send that sends a fixed sized buffer. @@ -356,15 +339,14 @@ pub trait SendStream: SendStreamSpecific { ) -> impl Future> + Send; /// Sync the stream. Not needed for iroh, but needed for intermediate buffered streams such as compression. fn sync(&mut self) -> impl Future> + Send; -} - -pub trait RecvStreamSpecific: Send { - /// Stop the stream with the given error code. 
- fn stop(&mut self, code: VarInt) -> io::Result<()>; + /// Reset the stream with the given error code. + fn reset(&mut self, code: VarInt) -> io::Result<()>; + /// Wait for the stream to be stopped, returning the error code if it was. + fn stopped(&mut self) -> impl Future>> + Send; } /// An abstract `iroh::endpoint::RecvStream`. -pub trait RecvStream: RecvStreamSpecific { +pub trait RecvStream: Send { /// Receive up to `len` bytes from the stream, directly into a `Bytes`. fn recv_bytes(&mut self, len: usize) -> impl Future> + Send; /// Receive exactly `len` bytes from the stream, directly into a `Bytes`. @@ -375,6 +357,10 @@ pub trait RecvStream: RecvStreamSpecific { fn recv_bytes_exact(&mut self, len: usize) -> impl Future> + Send; /// Receive exactly `L` bytes from the stream, directly into a `[u8; L]`. fn recv(&mut self) -> impl Future> + Send; + /// Stop the stream with the given error code. + fn stop(&mut self, code: VarInt) -> io::Result<()>; + /// Get the stream id. + fn id(&self) -> u64; } impl SendStream for iroh::endpoint::SendStream { @@ -389,9 +375,7 @@ impl SendStream for iroh::endpoint::SendStream { async fn sync(&mut self) -> io::Result<()> { Ok(()) } -} -impl SendStreamSpecific for iroh::endpoint::SendStream { fn reset(&mut self, code: VarInt) -> io::Result<()> { Ok(self.reset(code)?) } @@ -435,12 +419,14 @@ impl RecvStream for iroh::endpoint::RecvStream { })?; Ok(buf) } -} -impl RecvStreamSpecific for iroh::endpoint::RecvStream { fn stop(&mut self, code: VarInt) -> io::Result<()> { Ok(self.stop(code)?) } + + fn id(&self) -> u64 { + self.id().index() + } } impl RecvStream for &mut R { @@ -455,12 +441,14 @@ impl RecvStream for &mut R { async fn recv(&mut self) -> io::Result<[u8; L]> { self.deref_mut().recv::().await } -} -impl RecvStreamSpecific for &mut R { fn stop(&mut self, code: VarInt) -> io::Result<()> { self.deref_mut().stop(code) } + + fn id(&self) -> u64 { + self.deref().id() + } } impl SendStream for &mut W { @@ -475,9 +463,7 @@ impl SendStream for &mut W { async fn sync(&mut self) -> io::Result<()> { self.deref_mut().sync().await } -} -impl SendStreamSpecific for &mut W { fn reset(&mut self, code: VarInt) -> io::Result<()> { self.deref_mut().reset(code) } @@ -494,20 +480,14 @@ impl AsyncReadRecvStream { pub fn new(inner: R) -> Self { Self(inner) } - - pub fn into_inner(self) -> R { - self.0 - } } -use tokio::io::AsyncReadExt; - -impl RecvStream for AsyncReadRecvStream { +impl RecvStream for AsyncReadRecvStream { async fn recv_bytes(&mut self, len: usize) -> io::Result { let mut res = vec![0; len]; let mut n = 0; loop { - let read = self.0.read(&mut res[n..]).await?; + let read = self.0.inner().read(&mut res[n..]).await?; if read == 0 { res.truncate(n); break; @@ -522,21 +502,35 @@ impl RecvStream for AsyncReadRecvStre async fn recv_bytes_exact(&mut self, len: usize) -> io::Result { let mut res = vec![0; len]; - self.0.read_exact(&mut res).await?; + self.0.inner().read_exact(&mut res).await?; Ok(res.into()) } async fn recv(&mut self) -> io::Result<[u8; L]> { let mut res = [0; L]; - self.0.read_exact(&mut res).await?; + self.0.inner().read_exact(&mut res).await?; Ok(res) } -} -impl RecvStreamSpecific for AsyncReadRecvStream { fn stop(&mut self, code: VarInt) -> io::Result<()> { self.0.stop(code) } + + fn id(&self) -> u64 { + self.0.id() + } +} + +pub trait RecvStreamSpecific: Send { + fn inner(&mut self) -> &mut (impl AsyncRead + Unpin + Send); + fn stop(&mut self, code: VarInt) -> io::Result<()>; + fn id(&self) -> u64; +} + +pub trait SendStreamSpecific: Send { + 
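    // Implementors only supply the raw `AsyncWrite` plus the QUIC-level
    // `reset`/`stopped` hooks; wrapping such a type in `AsyncWriteSendStream`
    // then provides the full `SendStream` API on top.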
fn inner(&mut self) -> &mut (impl AsyncWrite + Unpin + Send); + fn reset(&mut self, code: VarInt) -> io::Result<()>; + fn stopped(&mut self) -> impl Future>> + Send; } impl RecvStream for Bytes { @@ -565,52 +559,53 @@ impl RecvStream for Bytes { *self = self.slice(L..); Ok(res) } -} -impl RecvStreamSpecific for Bytes { fn stop(&mut self, _code: VarInt) -> io::Result<()> { Ok(()) } + + fn id(&self) -> u64 { + 0 + } } /// Utility to convert a [tokio::io::AsyncWrite] into an [SendStream]. #[derive(Debug, Clone)] pub struct AsyncWriteSendStream(W); -impl AsyncWriteSendStream { +impl AsyncWriteSendStream { pub fn new(inner: W) -> Self { Self(inner) } +} +impl AsyncWriteSendStream { pub fn into_inner(self) -> W { self.0 } } -use tokio::io::AsyncWriteExt; - -impl SendStream for AsyncWriteSendStream { +impl SendStream for AsyncWriteSendStream { async fn send_bytes(&mut self, bytes: Bytes) -> io::Result<()> { - self.0.write_all(&bytes).await + self.0.inner().write_all(&bytes).await } async fn send(&mut self, buf: &[u8; L]) -> io::Result<()> { - self.0.write_all(buf).await + self.0.inner().write_all(buf).await } async fn sync(&mut self) -> io::Result<()> { - self.0.flush().await + self.0.inner().flush().await } -} -impl SendStreamSpecific for AsyncWriteSendStream { fn reset(&mut self, code: VarInt) -> io::Result<()> { self.0.reset(code)?; Ok(()) } async fn stopped(&mut self) -> io::Result> { - Ok(self.0.stopped().await?) + let res = self.0.stopped().await?; + Ok(res) } } @@ -695,11 +690,11 @@ async fn handle_read_result( } } -pub async fn handle_stream(mut pair: StreamPair, store: Store) -> anyhow::Result<()> { - // 1. Decode the request. - debug!("reading request"); +pub async fn handle_stream( + mut pair: StreamPair, + store: Store, +) -> anyhow::Result<()> { let request = pair.read_request().await?; - match request { Request::Get(request) => handle_get(pair, store, request).await?, Request::GetMany(request) => handle_get_many(pair, store, request).await?, From bc159cace4a7b2dfbda7fe8da479e92a9bf292eb Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Thu, 11 Sep 2025 14:27:12 +0300 Subject: [PATCH 30/35] More moving stuff around --- Cargo.lock | 7 + Cargo.toml | 1 + examples/compression.rs | 138 ++++++++---- src/api/blobs.rs | 8 +- src/api/remote.rs | 127 +++++------ src/get.rs | 27 ++- src/protocol.rs | 4 +- src/provider.rs | 478 +++------------------------------------- src/util.rs | 6 + src/util/stream.rs | 433 ++++++++++++++++++++++++++++++++++++ 10 files changed, 650 insertions(+), 579 deletions(-) create mode 100644 src/util/stream.rs diff --git a/Cargo.lock b/Cargo.lock index 625f30b7b..8f800e160 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -541,6 +541,12 @@ version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e47641d3deaf41fb1538ac1f54735925e275eaf3bf4d55c81b137fba797e5cbb" +[[package]] +name = "concat_const" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60c92cd5ec953d0542f48d2a90a25aa2828ab1c03217c1ca077000f3af15997d" + [[package]] name = "const-oid" version = "0.9.6" @@ -1780,6 +1786,7 @@ dependencies = [ "bytes", "chrono", "clap", + "concat_const", "data-encoding", "derive_more 2.0.1", "futures-lite", diff --git a/Cargo.toml b/Cargo.toml index a40b735bd..ef0e01d4a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -62,6 +62,7 @@ walkdir = "2.5.0" atomic_refcell = "0.1.13" iroh = { version = "0.91.1", features = ["discovery-local-network"]} async-compression = { version = "0.4.30", features = 
["zstd", "tokio"] } +concat_const = "0.2.0" [features] hide-proto-docs = [] diff --git a/examples/compression.rs b/examples/compression.rs index 5cc9cbf30..af7dd77ef 100644 --- a/examples/compression.rs +++ b/examples/compression.rs @@ -9,24 +9,23 @@ /// grade code you might nevertheless put the tasks into a [tokio::task::JoinSet] or /// [n0_future::FuturesUnordered]. mod common; -use std::{io, path::PathBuf}; +use std::{fmt::Debug, path::PathBuf}; use anyhow::Result; -use async_compression::tokio::{bufread::Lz4Decoder, write::Lz4Encoder}; use clap::Parser; use common::setup_logging; -use iroh::{endpoint::VarInt, protocol::ProtocolHandler}; +use iroh::protocol::ProtocolHandler; use iroh_blobs::{ api::Store, + get::StreamPair, provider::{ + self, events::{ClientConnected, EventSender, HasErrorCode}, - handle_stream, AsyncReadRecvStream, AsyncWriteSendStream, RecvStreamSpecific, - SendStreamSpecific, StreamPair, + handle_stream, }, store::mem::MemStore, ticket::BlobTicket, }; -use tokio::io::{AsyncRead, AsyncWrite, BufReader}; use tracing::debug; use crate::common::get_or_generate_secret_key; @@ -49,54 +48,110 @@ pub enum Args { }, } -struct CompressedWriteStream(Lz4Encoder); +trait Compression: Clone + Send + Sync + Debug + 'static { + const ALPN: &'static [u8]; + fn recv_stream( + &self, + stream: iroh::endpoint::RecvStream, + ) -> impl iroh_blobs::util::RecvStream + Sync + 'static; + fn send_stream( + &self, + stream: iroh::endpoint::SendStream, + ) -> impl iroh_blobs::util::SendStream + Sync + 'static; +} -impl SendStreamSpecific for CompressedWriteStream { - fn inner(&mut self) -> &mut (impl AsyncWrite + Unpin + Send) { - &mut self.0 - } +mod lz4 { + use std::io; + + use async_compression::tokio::{bufread::Lz4Decoder, write::Lz4Encoder}; + use iroh::endpoint::VarInt; + use iroh_blobs::util::{ + AsyncReadRecvStream, AsyncWriteSendStream, RecvStreamSpecific, SendStreamSpecific, + }; + use tokio::io::{AsyncRead, AsyncWrite, BufReader}; + + struct SendStream(Lz4Encoder); - fn reset(&mut self, code: VarInt) -> io::Result<()> { - Ok(self.0.get_mut().reset(code)?) + impl SendStream { + pub fn new(inner: iroh::endpoint::SendStream) -> AsyncWriteSendStream { + AsyncWriteSendStream::new(Self(Lz4Encoder::new(inner))) + } } - async fn stopped(&mut self) -> io::Result> { - Ok(self.0.get_mut().stopped().await?) + impl SendStreamSpecific for SendStream { + fn inner(&mut self) -> &mut (impl AsyncWrite + Unpin + Send) { + &mut self.0 + } + + fn reset(&mut self, code: VarInt) -> io::Result<()> { + Ok(self.0.get_mut().reset(code)?) + } + + async fn stopped(&mut self) -> io::Result> { + Ok(self.0.get_mut().stopped().await?) + } } -} -struct CompressedReadStream(Lz4Decoder>); + struct RecvStream(Lz4Decoder>); -impl RecvStreamSpecific for CompressedReadStream { - fn inner(&mut self) -> &mut (impl AsyncRead + Unpin + Send) { - &mut self.0 + impl RecvStream { + pub fn new(inner: iroh::endpoint::RecvStream) -> AsyncReadRecvStream { + AsyncReadRecvStream::new(Self(Lz4Decoder::new(BufReader::new(inner)))) + } } - fn stop(&mut self, code: VarInt) -> io::Result<()> { - Ok(self.0.get_mut().get_mut().stop(code)?) + impl RecvStreamSpecific for RecvStream { + fn inner(&mut self) -> &mut (impl AsyncRead + Unpin + Send) { + &mut self.0 + } + + fn stop(&mut self, code: VarInt) -> io::Result<()> { + Ok(self.0.get_mut().get_mut().stop(code)?) 
+ } + + fn id(&self) -> u64 { + self.0.get_ref().get_ref().id().index() + } } - fn id(&self) -> u64 { - self.0.get_ref().get_ref().id().index() + #[derive(Debug, Clone)] + pub struct Compression; + + impl super::Compression for Compression { + const ALPN: &[u8] = concat_const::concat_bytes!(b"lz4/", iroh_blobs::ALPN); + fn recv_stream( + &self, + stream: iroh::endpoint::RecvStream, + ) -> impl iroh_blobs::util::RecvStream + Sync + 'static { + RecvStream::new(stream) + } + fn send_stream( + &self, + stream: iroh::endpoint::SendStream, + ) -> impl iroh_blobs::util::SendStream + Sync + 'static { + SendStream::new(stream) + } } } #[derive(Debug, Clone)] -struct CompressedBlobsProtocol { +struct CompressedBlobsProtocol { store: Store, events: EventSender, + compression: C, } -impl CompressedBlobsProtocol { - fn new(store: &Store, events: EventSender) -> Self { +impl CompressedBlobsProtocol { + fn new(store: &Store, events: EventSender, compression: C) -> Self { Self { store: store.clone(), events, + compression, } } } -impl ProtocolHandler for CompressedBlobsProtocol { +impl ProtocolHandler for CompressedBlobsProtocol { async fn accept( &self, connection: iroh::endpoint::Connection, @@ -116,20 +171,16 @@ impl ProtocolHandler for CompressedBlobsProtocol { return Ok(()); } while let Ok((send, recv)) = connection.accept_bi().await { - let send = AsyncWriteSendStream::new(CompressedWriteStream(Lz4Encoder::new(send))); - let recv = AsyncReadRecvStream::new(CompressedReadStream(Lz4Decoder::new( - BufReader::new(recv), - ))); + let send = self.compression.send_stream(send); + let recv = self.compression.recv_stream(recv); let store = self.store.clone(); - let pair = StreamPair::new(connection_id, recv, send, self.events.clone()); + let pair = provider::StreamPair::new(connection_id, recv, send, self.events.clone()); tokio::spawn(handle_stream(pair, store)); } Ok(()) } } -const ALPN: &[u8] = b"iroh-blobs-compressed/0.1.0"; - #[tokio::main] async fn main() -> Result<()> { setup_logging(); @@ -140,13 +191,14 @@ async fn main() -> Result<()> { .discovery_n0() .bind() .await?; + let compression = lz4::Compression; match args { Args::Provide { path } => { let store = MemStore::new(); let tag = store.add_path(path).await?; - let blobs = CompressedBlobsProtocol::new(&store, EventSender::DEFAULT); + let blobs = CompressedBlobsProtocol::new(&store, EventSender::DEFAULT, compression); let router = iroh::protocol::Router::builder(endpoint.clone()) - .accept(ALPN, blobs) + .accept(lz4::Compression::ALPN, blobs) .spawn(); let ticket = BlobTicket::new(endpoint.node_id().into(), tag.hash, tag.format); println!("Serving blob with hash {}", tag.hash); @@ -158,14 +210,14 @@ async fn main() -> Result<()> { } Args::Get { ticket, target } => { let store = MemStore::new(); - let conn = endpoint.connect(ticket.node_addr().clone(), ALPN).await?; + let conn = endpoint + .connect(ticket.node_addr().clone(), &lz4::Compression::ALPN) + .await?; let connection_id = conn.stable_id() as u64; let (send, recv) = conn.open_bi().await?; - let send = AsyncWriteSendStream::new(CompressedWriteStream(Lz4Encoder::new(send))); - let recv = AsyncReadRecvStream::new(CompressedReadStream(Lz4Decoder::new( - BufReader::new(recv), - ))); - let sp = StreamPair::new(connection_id, recv, send, EventSender::DEFAULT); + let send = compression.send_stream(send); + let recv = compression.recv_stream(recv); + let sp = StreamPair::new(connection_id, recv, send); let stats = store.remote().fetch(sp, ticket.hash_and_format()).await?; if let Some(target) = 
target { let size = store.export(ticket.hash(), &target).await?; diff --git a/src/api/blobs.rs b/src/api/blobs.rs index 9d81f04bb..4e059c26d 100644 --- a/src/api/blobs.rs +++ b/src/api/blobs.rs @@ -55,9 +55,9 @@ use super::{ }; use crate::{ api::proto::{BatchRequest, ImportByteStreamUpdate}, - provider::{events::ClientResult, RecvStreamAsyncStreamReader}, + provider::events::ClientResult, store::IROH_BLOCK_SIZE, - util::temp_tag::TempTag, + util::{temp_tag::TempTag, RecvStreamAsyncStreamReader}, BlobFormat, Hash, HashAndFormat, }; @@ -429,7 +429,7 @@ impl Blobs { } #[cfg_attr(feature = "hide-proto-docs", doc(hidden))] - pub async fn import_bao_reader( + pub async fn import_bao_reader( &self, hash: Hash, ranges: ChunkRanges, @@ -1073,7 +1073,7 @@ impl ExportBaoProgress { } /// Write quinn variant that also feeds a progress writer. - pub(crate) async fn write_with_progress( + pub(crate) async fn write_with_progress( self, writer: &mut W, progress: &mut impl WriteProgress, diff --git a/src/api/remote.rs b/src/api/remote.rs index 242d9a5b5..c95f73353 100644 --- a/src/api/remote.rs +++ b/src/api/remote.rs @@ -1,32 +1,54 @@ //! API for downloading blobs from a single remote node. //! //! The entry point is the [`Remote`] struct. +use std::{ + collections::BTreeMap, + future::{Future, IntoFuture}, + num::NonZeroU64, + sync::Arc, +}; + +use bao_tree::{ + io::{BaoContentItem, Leaf}, + ChunkNum, ChunkRanges, +}; use genawaiter::sync::{Co, Gen}; -use iroh::endpoint::SendStream; +use iroh::endpoint::Connection; use irpc::util::{AsyncReadVarintExt, WriteVarintExt}; use n0_future::{io, Stream, StreamExt}; use n0_snafu::SpanTrace; use nested_enum_utils::common_fields; use ref_cast::RefCast; use snafu::{Backtrace, IntoError, ResultExt, Snafu}; +use tracing::{debug, trace}; use super::blobs::{Bitfield, ExportBaoOptions}; use crate::{ - api::{blobs::WriteProgress, ApiClient}, + api::{ + self, + blobs::{Blobs, WriteProgress}, + ApiClient, Store, + }, get::{ - fsm::{AtConnected, DecodeError}, + fsm::{ + AtBlobHeader, AtConnected, AtEndBlob, BlobContentNext, ConnectedNext, DecodeError, + EndBlobNext, + }, get_error::{BadRequestSnafu, LocalFailureSnafu}, - GetError, GetResult, Stats, + GetError, GetResult, Stats, StreamPair, }, + hashseq::{HashSeq, HashSeqIter}, protocol::{ - GetManyRequest, ObserveItem, ObserveRequest, PushRequest, Request, RequestType, - MAX_MESSAGE_SIZE, + ChunkRangesSeq, GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, + Request, RequestType, MAX_MESSAGE_SIZE, }, - provider::{ - events::{ClientResult, EventSender, ProgressError}, - RecvStream, StreamPair, + provider::events::{ClientResult, ProgressError}, + store::IROH_BLOCK_SIZE, + util::{ + sink::{Sink, TokioMpscSenderSink}, + RecvStream, SendStream, }, - util::sink::{Sink, TokioMpscSenderSink}, + Hash, HashAndFormat, }; /// API to compute request and to download from remote nodes. @@ -663,13 +685,8 @@ impl Remote { })?; // I am cloning the connection, but it's fine because the original connection or ConnectionRef stays alive // for the duration of the operation. - let connected = AtConnected::new( - conn.t0, - conn.reader, - conn.writer, - request, - Default::default(), - ); + let connected = + AtConnected::new(conn.t0, conn.recv, conn.send, request, Default::default()); trace!("Getting header"); // read the header let next_child = match connected.next().await? 
{ @@ -828,74 +845,29 @@ pub enum ExecuteError { }, } -use std::{ - collections::BTreeMap, - future::{Future, IntoFuture}, - num::NonZeroU64, - sync::Arc, -}; - -use bao_tree::{ - io::{BaoContentItem, Leaf}, - ChunkNum, ChunkRanges, -}; -use iroh::endpoint::Connection; -use tracing::{debug, trace}; - -use crate::{ - api::{self, blobs::Blobs, Store}, - get::fsm::{AtBlobHeader, AtEndBlob, BlobContentNext, ConnectedNext, EndBlobNext}, - hashseq::{HashSeq, HashSeqIter}, - protocol::{ChunkRangesSeq, GetRequest}, - store::IROH_BLOCK_SIZE, - Hash, HashAndFormat, -}; - pub trait GetStreamPair: Send + 'static { fn open_stream_pair( self, - ) -> impl Future< - Output = io::Result< - StreamPair, - >, - > + Send - + 'static; + ) -> impl Future>> + Send + 'static; } -impl - GetStreamPair for StreamPair +impl GetStreamPair + for StreamPair { - fn open_stream_pair( + async fn open_stream_pair( self, - ) -> impl Future< - Output = io::Result< - StreamPair, - >, - > + Send - + 'static { - async move { Ok(self) } + ) -> io::Result> { + Ok(self) } } impl GetStreamPair for Connection { - fn open_stream_pair( + async fn open_stream_pair( self, - ) -> impl Future< - Output = io::Result< - StreamPair, - >, - > + Send - + 'static { + ) -> io::Result> { let connection_id = self.stable_id() as u64; - async move { - let (send, recv) = self.open_bi().await?; - Ok(StreamPair::new( - connection_id, - recv, - send, - EventSender::DEFAULT, - )) - } + let (send, recv) = self.open_bi().await?; + Ok(StreamPair::new(connection_id, recv, send)) } } @@ -1046,20 +1018,23 @@ impl LazyHashSeq { async fn write_push_request( request: PushRequest, - stream: &mut SendStream, + stream: &mut impl SendStream, ) -> anyhow::Result { let mut request_bytes = Vec::new(); request_bytes.push(RequestType::Push as u8); request_bytes.write_length_prefixed(&request).unwrap(); - stream.write_all(&request_bytes).await?; + stream.send_bytes(request_bytes.into()).await?; Ok(request) } -async fn write_observe_request(request: ObserveRequest, stream: &mut SendStream) -> io::Result<()> { +async fn write_observe_request( + request: ObserveRequest, + stream: &mut impl SendStream, +) -> io::Result<()> { let request = Request::Observe(request); let request_bytes = postcard::to_allocvec(&request) .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - stream.write_all(&request_bytes).await?; + stream.send_bytes(request_bytes.into()).await?; Ok(()) } diff --git a/src/get.rs b/src/get.rs index bcea4b1b7..234fe5382 100644 --- a/src/get.rs +++ b/src/get.rs @@ -30,7 +30,12 @@ use serde::{Deserialize, Serialize}; use snafu::{Backtrace, IntoError, ResultExt, Snafu}; use tracing::{debug, error}; -use crate::{protocol::ChunkRangesSeq, store::IROH_BLOCK_SIZE, Hash}; +use crate::{ + protocol::ChunkRangesSeq, + store::IROH_BLOCK_SIZE, + util::{RecvStream, SendStream}, + Hash, +}; mod error; pub mod request; @@ -40,6 +45,24 @@ pub use error::{GetError, GetResult}; type DefaultReader = iroh::endpoint::RecvStream; type DefaultWriter = iroh::endpoint::SendStream; +pub struct StreamPair { + pub connection_id: u64, + pub t0: Instant, + pub recv: R, + pub send: W, +} + +impl StreamPair { + pub fn new(connection_id: u64, recv: R, send: W) -> Self { + Self { + t0: Instant::now(), + recv, + send, + connection_id, + } + } +} + /// Stats about the transfer. 
#[derive( Debug, @@ -102,7 +125,7 @@ pub mod fsm { protocol::{ GetManyRequest, GetRequest, NonEmptyRequestRangeSpecIter, Request, MAX_MESSAGE_SIZE, }, - provider::{RecvStream, RecvStreamAsyncStreamReader, SendStream}, + util::{RecvStream, RecvStreamAsyncStreamReader, SendStream}, }; self_cell::self_cell! { diff --git a/src/protocol.rs b/src/protocol.rs index 8fcd28e9f..db5faf060 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -390,7 +390,7 @@ pub use bao_tree::ChunkRanges; pub use range_spec::{ChunkRangesSeq, NonEmptyRequestRangeSpecIter, RangeSpec}; use snafu::{GenerateImplicitData, Snafu}; -use crate::{api::blobs::Bitfield, provider::RecvStreamExt, BlobFormat, Hash, HashAndFormat}; +use crate::{api::blobs::Bitfield, util::RecvStreamExt, BlobFormat, Hash, HashAndFormat}; /// Maximum message size is limited to 100MiB for now. pub const MAX_MESSAGE_SIZE: usize = 1024 * 1024; @@ -446,7 +446,7 @@ pub enum RequestType { } impl Request { - pub async fn read_async( + pub async fn read_async( reader: &mut R, ) -> io::Result<(Self, usize)> { let request_type = reader.read_u8().await?; diff --git a/src/provider.rs b/src/provider.rs index 633cbd5f7..7da6aaf91 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -7,23 +7,18 @@ use std::{ fmt::Debug, future::Future, io, - ops::{Deref, DerefMut}, time::{Duration, Instant}, }; use anyhow::Result; use bao_tree::ChunkRanges; -use bytes::Bytes; use iroh::endpoint; use iroh_io::{AsyncStreamReader, AsyncStreamWriter}; use n0_future::StreamExt; -use quinn::{ConnectionError, ReadExactError, VarInt}; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use quinn::{ConnectionError, VarInt}; +use serde::{Deserialize, Serialize}; use snafu::Snafu; -use tokio::{ - io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, - select, -}; +use tokio::select; use tracing::{debug, debug_span, warn, Instrument}; use crate::{ @@ -39,6 +34,7 @@ use crate::{ ClientConnected, ClientResult, ConnectionClosed, HasErrorCode, ProgressError, RequestTracker, }, + util::{RecvStream, SendStream, SendStreamExt, RecvStreamExt}, Hash, }; pub mod events; @@ -68,10 +64,10 @@ pub struct TransferStats { /// A pair of [`SendStream`] and [`RecvStream`] with additional context data. #[derive(Debug)] pub struct StreamPair { - pub t0: Instant, + t0: Instant, connection_id: u64, - pub reader: R, - pub writer: W, + reader: R, + writer: W, other_bytes_read: u64, events: EventSender, } @@ -86,7 +82,7 @@ impl StreamPair { } } -impl StreamPair { +impl StreamPair { pub fn stream_id(&self) -> u64 { self.reader.id() } @@ -262,13 +258,13 @@ impl WriteProgress for WriterContext { /// Wrapper for a [`quinn::SendStream`] with additional per request information. #[derive(Debug)] -pub struct ProgressWriter { +pub struct ProgressWriter { /// The quinn::SendStream to write to pub inner: W, pub(crate) context: WriterContext, } -impl ProgressWriter { +impl ProgressWriter { fn new(inner: W, context: WriterContext) -> Self { Self { inner, context } } @@ -328,310 +324,6 @@ pub async fn handle_connection( .await } -/// An abstract `iroh::endpoint::SendStream`. -pub trait SendStream: Send { - /// Send bytes to the stream. This takes a `Bytes` because iroh can directly use them. - fn send_bytes(&mut self, bytes: Bytes) -> impl Future> + Send; - /// Send that sends a fixed sized buffer. - fn send( - &mut self, - buf: &[u8; L], - ) -> impl Future> + Send; - /// Sync the stream. Not needed for iroh, but needed for intermediate buffered streams such as compression. 
- fn sync(&mut self) -> impl Future> + Send; - /// Reset the stream with the given error code. - fn reset(&mut self, code: VarInt) -> io::Result<()>; - /// Wait for the stream to be stopped, returning the error code if it was. - fn stopped(&mut self) -> impl Future>> + Send; -} - -/// An abstract `iroh::endpoint::RecvStream`. -pub trait RecvStream: Send { - /// Receive up to `len` bytes from the stream, directly into a `Bytes`. - fn recv_bytes(&mut self, len: usize) -> impl Future> + Send; - /// Receive exactly `len` bytes from the stream, directly into a `Bytes`. - /// - /// This will return an error if the stream ends before `len` bytes are read. - /// - /// Note that this is different from `recv_bytes`, which will return fewer bytes if the stream ends. - fn recv_bytes_exact(&mut self, len: usize) -> impl Future> + Send; - /// Receive exactly `L` bytes from the stream, directly into a `[u8; L]`. - fn recv(&mut self) -> impl Future> + Send; - /// Stop the stream with the given error code. - fn stop(&mut self, code: VarInt) -> io::Result<()>; - /// Get the stream id. - fn id(&self) -> u64; -} - -impl SendStream for iroh::endpoint::SendStream { - async fn send_bytes(&mut self, bytes: Bytes) -> io::Result<()> { - Ok(self.write_chunk(bytes).await?) - } - - async fn send(&mut self, buf: &[u8; L]) -> io::Result<()> { - Ok(self.write_all(buf).await?) - } - - async fn sync(&mut self) -> io::Result<()> { - Ok(()) - } - - fn reset(&mut self, code: VarInt) -> io::Result<()> { - Ok(self.reset(code)?) - } - - async fn stopped(&mut self) -> io::Result> { - Ok(self.stopped().await?) - } -} - -impl RecvStream for iroh::endpoint::RecvStream { - async fn recv_bytes(&mut self, len: usize) -> io::Result { - let mut buf = vec![0; len]; - match self.read_exact(&mut buf).await { - Err(ReadExactError::FinishedEarly(n)) => { - buf.truncate(n); - } - Err(ReadExactError::ReadError(e)) => { - return Err(e.into()); - } - Ok(()) => {} - }; - Ok(buf.into()) - } - - async fn recv_bytes_exact(&mut self, len: usize) -> io::Result { - let mut buf = vec![0; len]; - self.read_exact(&mut buf).await.map_err(|e| match e { - ReadExactError::FinishedEarly(0) => io::Error::new(io::ErrorKind::UnexpectedEof, ""), - ReadExactError::FinishedEarly(_) => io::Error::new(io::ErrorKind::InvalidData, ""), - ReadExactError::ReadError(e) => e.into(), - })?; - Ok(buf.into()) - } - - async fn recv(&mut self) -> io::Result<[u8; L]> { - let mut buf = [0; L]; - self.read_exact(&mut buf).await.map_err(|e| match e { - ReadExactError::FinishedEarly(0) => io::Error::new(io::ErrorKind::UnexpectedEof, ""), - ReadExactError::FinishedEarly(_) => io::Error::new(io::ErrorKind::InvalidData, ""), - ReadExactError::ReadError(e) => e.into(), - })?; - Ok(buf) - } - - fn stop(&mut self, code: VarInt) -> io::Result<()> { - Ok(self.stop(code)?) 
- } - - fn id(&self) -> u64 { - self.id().index() - } -} - -impl RecvStream for &mut R { - async fn recv_bytes(&mut self, len: usize) -> io::Result { - self.deref_mut().recv_bytes(len).await - } - - async fn recv_bytes_exact(&mut self, len: usize) -> io::Result { - self.deref_mut().recv_bytes_exact(len).await - } - - async fn recv(&mut self) -> io::Result<[u8; L]> { - self.deref_mut().recv::().await - } - - fn stop(&mut self, code: VarInt) -> io::Result<()> { - self.deref_mut().stop(code) - } - - fn id(&self) -> u64 { - self.deref().id() - } -} - -impl SendStream for &mut W { - async fn send_bytes(&mut self, bytes: Bytes) -> io::Result<()> { - self.deref_mut().send_bytes(bytes).await - } - - async fn send(&mut self, buf: &[u8; L]) -> io::Result<()> { - self.deref_mut().send(buf).await - } - - async fn sync(&mut self) -> io::Result<()> { - self.deref_mut().sync().await - } - - fn reset(&mut self, code: VarInt) -> io::Result<()> { - self.deref_mut().reset(code) - } - - async fn stopped(&mut self) -> io::Result> { - self.deref_mut().stopped().await - } -} - -#[derive(Debug)] -pub struct AsyncReadRecvStream(R); - -impl AsyncReadRecvStream { - pub fn new(inner: R) -> Self { - Self(inner) - } -} - -impl RecvStream for AsyncReadRecvStream { - async fn recv_bytes(&mut self, len: usize) -> io::Result { - let mut res = vec![0; len]; - let mut n = 0; - loop { - let read = self.0.inner().read(&mut res[n..]).await?; - if read == 0 { - res.truncate(n); - break; - } - n += read; - if n == len { - break; - } - } - Ok(res.into()) - } - - async fn recv_bytes_exact(&mut self, len: usize) -> io::Result { - let mut res = vec![0; len]; - self.0.inner().read_exact(&mut res).await?; - Ok(res.into()) - } - - async fn recv(&mut self) -> io::Result<[u8; L]> { - let mut res = [0; L]; - self.0.inner().read_exact(&mut res).await?; - Ok(res) - } - - fn stop(&mut self, code: VarInt) -> io::Result<()> { - self.0.stop(code) - } - - fn id(&self) -> u64 { - self.0.id() - } -} - -pub trait RecvStreamSpecific: Send { - fn inner(&mut self) -> &mut (impl AsyncRead + Unpin + Send); - fn stop(&mut self, code: VarInt) -> io::Result<()>; - fn id(&self) -> u64; -} - -pub trait SendStreamSpecific: Send { - fn inner(&mut self) -> &mut (impl AsyncWrite + Unpin + Send); - fn reset(&mut self, code: VarInt) -> io::Result<()>; - fn stopped(&mut self) -> impl Future>> + Send; -} - -impl RecvStream for Bytes { - async fn recv_bytes(&mut self, len: usize) -> io::Result { - let n = len.min(self.len()); - let res = self.slice(..n); - *self = self.slice(n..); - Ok(res) - } - - async fn recv_bytes_exact(&mut self, len: usize) -> io::Result { - if self.len() < len { - return Err(io::ErrorKind::UnexpectedEof.into()); - } - let res = self.slice(..len); - *self = self.slice(len..); - Ok(res) - } - - async fn recv(&mut self) -> io::Result<[u8; L]> { - if self.len() < L { - return Err(io::ErrorKind::UnexpectedEof.into()); - } - let mut res = [0; L]; - res.copy_from_slice(&self[..L]); - *self = self.slice(L..); - Ok(res) - } - - fn stop(&mut self, _code: VarInt) -> io::Result<()> { - Ok(()) - } - - fn id(&self) -> u64 { - 0 - } -} - -/// Utility to convert a [tokio::io::AsyncWrite] into an [SendStream]. 
-#[derive(Debug, Clone)] -pub struct AsyncWriteSendStream(W); - -impl AsyncWriteSendStream { - pub fn new(inner: W) -> Self { - Self(inner) - } -} - -impl AsyncWriteSendStream { - pub fn into_inner(self) -> W { - self.0 - } -} - -impl SendStream for AsyncWriteSendStream { - async fn send_bytes(&mut self, bytes: Bytes) -> io::Result<()> { - self.0.inner().write_all(&bytes).await - } - - async fn send(&mut self, buf: &[u8; L]) -> io::Result<()> { - self.0.inner().write_all(buf).await - } - - async fn sync(&mut self) -> io::Result<()> { - self.0.inner().flush().await - } - - fn reset(&mut self, code: VarInt) -> io::Result<()> { - self.0.reset(code)?; - Ok(()) - } - - async fn stopped(&mut self) -> io::Result> { - let res = self.0.stopped().await?; - Ok(res) - } -} - -#[derive(Debug)] -pub struct RecvStreamAsyncStreamReader(R); - -impl RecvStreamAsyncStreamReader { - pub fn new(inner: R) -> Self { - Self(inner) - } - - pub fn into_inner(self) -> R { - self.0 - } -} - -impl AsyncStreamReader for RecvStreamAsyncStreamReader { - async fn read_bytes(&mut self, len: usize) -> io::Result { - self.0.recv_bytes_exact(len).await - } - - async fn read(&mut self) -> io::Result<[u8; L]> { - self.0.recv::().await - } -} - /// Describes how to handle errors for a stream. pub trait ErrorHandler { type W: AsyncStreamWriter; @@ -641,8 +333,8 @@ pub trait ErrorHandler { } async fn handle_read_request_result< - R: crate::provider::RecvStream, - W: crate::provider::SendStream, + R: RecvStream, + W: SendStream, T, E: HasErrorCode, >( @@ -657,7 +349,7 @@ async fn handle_read_request_result< } } } -async fn handle_write_result( +async fn handle_write_result( writer: &mut ProgressWriter, r: Result, ) -> Result { @@ -673,7 +365,7 @@ async fn handle_write_result } } } -async fn handle_read_result( +async fn handle_read_result( reader: &mut ProgressReader, r: Result, ) -> Result { @@ -732,7 +424,7 @@ impl HasErrorCode for HandleGetError { /// Handle a single get request. /// /// Requires a database, the request, and a writer. -async fn handle_get_impl( +async fn handle_get_impl( store: Store, request: GetRequest, writer: &mut ProgressWriter, @@ -775,7 +467,7 @@ async fn handle_get_impl( Ok(()) } -pub async fn handle_get( +pub async fn handle_get( mut pair: StreamPair, store: Store, request: GetRequest, @@ -808,7 +500,7 @@ impl HasErrorCode for HandleGetManyError { /// Handle a single get request. /// /// Requires a database, the request, and a writer. -async fn handle_get_many_impl( +async fn handle_get_many_impl( store: Store, request: GetManyRequest, writer: &mut ProgressWriter, @@ -823,7 +515,7 @@ async fn handle_get_many_impl( Ok(()) } -pub async fn handle_get_many( +pub async fn handle_get_many( mut pair: StreamPair, store: Store, request: GetManyRequest, @@ -865,7 +557,7 @@ impl HasErrorCode for HandlePushError { /// Handle a single push request. /// /// Requires a database, the request, and a reader. -async fn handle_push_impl( +async fn handle_push_impl( store: Store, request: PushRequest, reader: &mut ProgressReader, @@ -898,7 +590,7 @@ async fn handle_push_impl( Ok(()) } -pub async fn handle_push( +pub async fn handle_push( mut pair: StreamPair, store: Store, request: PushRequest, @@ -912,7 +604,7 @@ pub async fn handle_push( +pub(crate) async fn send_blob( store: &Store, index: u64, hash: Hash, @@ -944,7 +636,7 @@ impl HasErrorCode for HandleObserveError { /// Handle a single push request. /// /// Requires a database, the request, and a reader. 
-async fn handle_observe_impl( +async fn handle_observe_impl( store: Store, request: ObserveRequest, writer: &mut ProgressWriter, @@ -981,7 +673,7 @@ async fn handle_observe_impl( Ok(()) } -async fn send_observe_item( +async fn send_observe_item( writer: &mut ProgressWriter, item: &Bitfield, ) -> io::Result<()> { @@ -991,7 +683,7 @@ async fn send_observe_item( Ok(()) } -pub async fn handle_observe( +pub async fn handle_observe( mut pair: StreamPair, store: Store, request: ObserveRequest, @@ -1004,12 +696,12 @@ pub async fn handle_observe { +pub struct ProgressReader { inner: R, context: ReaderContext, } -impl ProgressReader { +impl ProgressReader { async fn transfer_aborted(&self) { self.context .tracker @@ -1026,121 +718,3 @@ impl ProgressReader { .ok(); } } - -pub(crate) trait RecvStreamExt: crate::provider::RecvStream { - async fn expect_eof(&mut self) -> io::Result<()> { - match self.read_u8().await { - Ok(_) => Err(io::Error::new( - io::ErrorKind::InvalidData, - "unexpected data", - )), - Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => Ok(()), - Err(e) => Err(e), - } - } - - async fn read_u8(&mut self) -> io::Result { - let buf = self.recv::<1>().await?; - Ok(buf[0]) - } - - async fn read_to_end_as( - &mut self, - max_size: usize, - ) -> io::Result<(T, usize)> { - let data = self.recv_bytes(max_size).await?; - self.expect_eof().await?; - let value = postcard::from_bytes(&data) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - Ok((value, data.len())) - } - - async fn read_length_prefixed( - &mut self, - max_size: usize, - ) -> io::Result { - let Some(n) = self.read_varint_u64().await? else { - return Err(io::ErrorKind::UnexpectedEof.into()); - }; - if n > max_size as u64 { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "length prefix too large", - )); - } - let n = n as usize; - let data = self.recv_bytes(n).await?; - let value = postcard::from_bytes(&data) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - Ok(value) - } - - /// Reads a u64 varint from an AsyncRead source, using the Postcard/LEB128 format. - /// - /// In Postcard's varint format (LEB128): - /// - Each byte uses 7 bits for the value - /// - The MSB (most significant bit) of each byte indicates if there are more bytes (1) or not (0) - /// - Values are stored in little-endian order (least significant group first) - /// - /// Returns the decoded u64 value. 
- async fn read_varint_u64(&mut self) -> io::Result> { - let mut result: u64 = 0; - let mut shift: u32 = 0; - - loop { - // We can only shift up to 63 bits (for a u64) - if shift >= 64 { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "Varint is too large for u64", - )); - } - - // Read a single byte - let res = self.read_u8().await; - if shift == 0 { - if let Err(cause) = res { - if cause.kind() == io::ErrorKind::UnexpectedEof { - return Ok(None); - } else { - return Err(cause); - } - } - } - - let byte = res?; - - // Extract the 7 value bits (bits 0-6, excluding the MSB which is the continuation bit) - let value = (byte & 0x7F) as u64; - - // Add the bits to our result at the current shift position - result |= value << shift; - - // If the high bit is not set (0), this is the last byte - if byte & 0x80 == 0 { - break; - } - - // Move to the next 7 bits - shift += 7; - } - - Ok(Some(result)) - } -} - -impl RecvStreamExt for R {} - -pub(crate) trait SendStreamExt: crate::provider::SendStream { - async fn write_length_prefixed(&mut self, value: T) -> io::Result { - let size = postcard::experimental::serialized_size(&value) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - let mut buf = Vec::with_capacity(size + 9); - irpc::util::WriteVarintExt::write_length_prefixed(&mut buf, value)?; - let n = buf.len(); - self.send_bytes(buf.into()).await?; - Ok(n) - } -} - -impl SendStreamExt for W {} diff --git a/src/util.rs b/src/util.rs index 40abf0343..b90c9d8a5 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,7 +1,13 @@ //! Utilities pub(crate) mod channel; pub mod connection_pool; +mod stream; pub(crate) mod temp_tag; +pub use stream::{ + AsyncReadRecvStream, AsyncWriteSendStream, RecvStream, RecvStreamAsyncStreamReader, + RecvStreamSpecific, SendStream, SendStreamSpecific, +}; +pub(crate) use stream::{SendStreamExt, RecvStreamExt}; pub(crate) mod serde { // Module that handles io::Error serialization/deserialization diff --git a/src/util/stream.rs b/src/util/stream.rs new file mode 100644 index 000000000..db08b0d54 --- /dev/null +++ b/src/util/stream.rs @@ -0,0 +1,433 @@ +use std::{ + future::Future, + io, + ops::{Deref, DerefMut}, +}; + +use bytes::Bytes; +use iroh::endpoint::{ReadExactError, VarInt}; +use iroh_io::AsyncStreamReader; +use serde::{de::DeserializeOwned, Serialize}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +/// An abstract `iroh::endpoint::SendStream`. +pub trait SendStream: Send { + /// Send bytes to the stream. This takes a `Bytes` because iroh can directly use them. + fn send_bytes(&mut self, bytes: Bytes) -> impl Future> + Send; + /// Send that sends a fixed sized buffer. + fn send( + &mut self, + buf: &[u8; L], + ) -> impl Future> + Send; + /// Sync the stream. Not needed for iroh, but needed for intermediate buffered streams such as compression. + fn sync(&mut self) -> impl Future> + Send; + /// Reset the stream with the given error code. + fn reset(&mut self, code: VarInt) -> io::Result<()>; + /// Wait for the stream to be stopped, returning the error code if it was. + fn stopped(&mut self) -> impl Future>> + Send; +} + +/// An abstract `iroh::endpoint::RecvStream`. +pub trait RecvStream: Send { + /// Receive up to `len` bytes from the stream, directly into a `Bytes`. + fn recv_bytes(&mut self, len: usize) -> impl Future> + Send; + /// Receive exactly `len` bytes from the stream, directly into a `Bytes`. + /// + /// This will return an error if the stream ends before `len` bytes are read. 
+ /// + /// Note that this is different from `recv_bytes`, which will return fewer bytes if the stream ends. + fn recv_bytes_exact(&mut self, len: usize) -> impl Future> + Send; + /// Receive exactly `L` bytes from the stream, directly into a `[u8; L]`. + fn recv(&mut self) -> impl Future> + Send; + /// Stop the stream with the given error code. + fn stop(&mut self, code: VarInt) -> io::Result<()>; + /// Get the stream id. + fn id(&self) -> u64; +} + +impl SendStream for iroh::endpoint::SendStream { + async fn send_bytes(&mut self, bytes: Bytes) -> io::Result<()> { + Ok(self.write_chunk(bytes).await?) + } + + async fn send(&mut self, buf: &[u8; L]) -> io::Result<()> { + Ok(self.write_all(buf).await?) + } + + async fn sync(&mut self) -> io::Result<()> { + Ok(()) + } + + fn reset(&mut self, code: VarInt) -> io::Result<()> { + Ok(self.reset(code)?) + } + + async fn stopped(&mut self) -> io::Result> { + Ok(self.stopped().await?) + } +} + +impl RecvStream for iroh::endpoint::RecvStream { + async fn recv_bytes(&mut self, len: usize) -> io::Result { + let mut buf = vec![0; len]; + match self.read_exact(&mut buf).await { + Err(ReadExactError::FinishedEarly(n)) => { + buf.truncate(n); + } + Err(ReadExactError::ReadError(e)) => { + return Err(e.into()); + } + Ok(()) => {} + }; + Ok(buf.into()) + } + + async fn recv_bytes_exact(&mut self, len: usize) -> io::Result { + let mut buf = vec![0; len]; + self.read_exact(&mut buf).await.map_err(|e| match e { + ReadExactError::FinishedEarly(0) => io::Error::new(io::ErrorKind::UnexpectedEof, ""), + ReadExactError::FinishedEarly(_) => io::Error::new(io::ErrorKind::InvalidData, ""), + ReadExactError::ReadError(e) => e.into(), + })?; + Ok(buf.into()) + } + + async fn recv(&mut self) -> io::Result<[u8; L]> { + let mut buf = [0; L]; + self.read_exact(&mut buf).await.map_err(|e| match e { + ReadExactError::FinishedEarly(0) => io::Error::new(io::ErrorKind::UnexpectedEof, ""), + ReadExactError::FinishedEarly(_) => io::Error::new(io::ErrorKind::InvalidData, ""), + ReadExactError::ReadError(e) => e.into(), + })?; + Ok(buf) + } + + fn stop(&mut self, code: VarInt) -> io::Result<()> { + Ok(self.stop(code)?) 
+ } + + fn id(&self) -> u64 { + self.id().index() + } +} + +impl RecvStream for &mut R { + async fn recv_bytes(&mut self, len: usize) -> io::Result { + self.deref_mut().recv_bytes(len).await + } + + async fn recv_bytes_exact(&mut self, len: usize) -> io::Result { + self.deref_mut().recv_bytes_exact(len).await + } + + async fn recv(&mut self) -> io::Result<[u8; L]> { + self.deref_mut().recv::().await + } + + fn stop(&mut self, code: VarInt) -> io::Result<()> { + self.deref_mut().stop(code) + } + + fn id(&self) -> u64 { + self.deref().id() + } +} + +impl SendStream for &mut W { + async fn send_bytes(&mut self, bytes: Bytes) -> io::Result<()> { + self.deref_mut().send_bytes(bytes).await + } + + async fn send(&mut self, buf: &[u8; L]) -> io::Result<()> { + self.deref_mut().send(buf).await + } + + async fn sync(&mut self) -> io::Result<()> { + self.deref_mut().sync().await + } + + fn reset(&mut self, code: VarInt) -> io::Result<()> { + self.deref_mut().reset(code) + } + + async fn stopped(&mut self) -> io::Result> { + self.deref_mut().stopped().await + } +} + +#[derive(Debug)] +pub struct AsyncReadRecvStream(R); + +impl AsyncReadRecvStream { + pub fn new(inner: R) -> Self { + Self(inner) + } +} + +impl RecvStream for AsyncReadRecvStream { + async fn recv_bytes(&mut self, len: usize) -> io::Result { + let mut res = vec![0; len]; + let mut n = 0; + loop { + let read = self.0.inner().read(&mut res[n..]).await?; + if read == 0 { + res.truncate(n); + break; + } + n += read; + if n == len { + break; + } + } + Ok(res.into()) + } + + async fn recv_bytes_exact(&mut self, len: usize) -> io::Result { + let mut res = vec![0; len]; + self.0.inner().read_exact(&mut res).await?; + Ok(res.into()) + } + + async fn recv(&mut self) -> io::Result<[u8; L]> { + let mut res = [0; L]; + self.0.inner().read_exact(&mut res).await?; + Ok(res) + } + + fn stop(&mut self, code: VarInt) -> io::Result<()> { + self.0.stop(code) + } + + fn id(&self) -> u64 { + self.0.id() + } +} + +pub trait RecvStreamSpecific: Send { + fn inner(&mut self) -> &mut (impl AsyncRead + Unpin + Send); + fn stop(&mut self, code: VarInt) -> io::Result<()>; + fn id(&self) -> u64; +} + +pub trait SendStreamSpecific: Send { + fn inner(&mut self) -> &mut (impl AsyncWrite + Unpin + Send); + fn reset(&mut self, code: VarInt) -> io::Result<()>; + fn stopped(&mut self) -> impl Future>> + Send; +} + +impl RecvStream for Bytes { + async fn recv_bytes(&mut self, len: usize) -> io::Result { + let n = len.min(self.len()); + let res = self.slice(..n); + *self = self.slice(n..); + Ok(res) + } + + async fn recv_bytes_exact(&mut self, len: usize) -> io::Result { + if self.len() < len { + return Err(io::ErrorKind::UnexpectedEof.into()); + } + let res = self.slice(..len); + *self = self.slice(len..); + Ok(res) + } + + async fn recv(&mut self) -> io::Result<[u8; L]> { + if self.len() < L { + return Err(io::ErrorKind::UnexpectedEof.into()); + } + let mut res = [0; L]; + res.copy_from_slice(&self[..L]); + *self = self.slice(L..); + Ok(res) + } + + fn stop(&mut self, _code: VarInt) -> io::Result<()> { + Ok(()) + } + + fn id(&self) -> u64 { + 0 + } +} + +/// Utility to convert a [tokio::io::AsyncWrite] into an [SendStream]. 
+#[derive(Debug, Clone)] +pub struct AsyncWriteSendStream(W); + +impl AsyncWriteSendStream { + pub fn new(inner: W) -> Self { + Self(inner) + } +} + +impl AsyncWriteSendStream { + pub fn into_inner(self) -> W { + self.0 + } +} + +impl SendStream for AsyncWriteSendStream { + async fn send_bytes(&mut self, bytes: Bytes) -> io::Result<()> { + self.0.inner().write_all(&bytes).await + } + + async fn send(&mut self, buf: &[u8; L]) -> io::Result<()> { + self.0.inner().write_all(buf).await + } + + async fn sync(&mut self) -> io::Result<()> { + self.0.inner().flush().await + } + + fn reset(&mut self, code: VarInt) -> io::Result<()> { + self.0.reset(code)?; + Ok(()) + } + + async fn stopped(&mut self) -> io::Result> { + let res = self.0.stopped().await?; + Ok(res) + } +} + +#[derive(Debug)] +pub struct RecvStreamAsyncStreamReader(R); + +impl RecvStreamAsyncStreamReader { + pub fn new(inner: R) -> Self { + Self(inner) + } + + pub fn into_inner(self) -> R { + self.0 + } +} + +impl AsyncStreamReader for RecvStreamAsyncStreamReader { + async fn read_bytes(&mut self, len: usize) -> io::Result { + self.0.recv_bytes_exact(len).await + } + + async fn read(&mut self) -> io::Result<[u8; L]> { + self.0.recv::().await + } +} + +pub(crate) trait RecvStreamExt: RecvStream { + async fn expect_eof(&mut self) -> io::Result<()> { + match self.read_u8().await { + Ok(_) => Err(io::Error::new( + io::ErrorKind::InvalidData, + "unexpected data", + )), + Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => Ok(()), + Err(e) => Err(e), + } + } + + async fn read_u8(&mut self) -> io::Result { + let buf = self.recv::<1>().await?; + Ok(buf[0]) + } + + async fn read_to_end_as( + &mut self, + max_size: usize, + ) -> io::Result<(T, usize)> { + let data = self.recv_bytes(max_size).await?; + self.expect_eof().await?; + let value = postcard::from_bytes(&data) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + Ok((value, data.len())) + } + + async fn read_length_prefixed( + &mut self, + max_size: usize, + ) -> io::Result { + let Some(n) = self.read_varint_u64().await? else { + return Err(io::ErrorKind::UnexpectedEof.into()); + }; + if n > max_size as u64 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "length prefix too large", + )); + } + let n = n as usize; + let data = self.recv_bytes(n).await?; + let value = postcard::from_bytes(&data) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + Ok(value) + } + + /// Reads a u64 varint from an AsyncRead source, using the Postcard/LEB128 format. + /// + /// In Postcard's varint format (LEB128): + /// - Each byte uses 7 bits for the value + /// - The MSB (most significant bit) of each byte indicates if there are more bytes (1) or not (0) + /// - Values are stored in little-endian order (least significant group first) + /// + /// Returns the decoded u64 value. 
+ async fn read_varint_u64(&mut self) -> io::Result> { + let mut result: u64 = 0; + let mut shift: u32 = 0; + + loop { + // We can only shift up to 63 bits (for a u64) + if shift >= 64 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Varint is too large for u64", + )); + } + + // Read a single byte + let res = self.read_u8().await; + if shift == 0 { + if let Err(cause) = res { + if cause.kind() == io::ErrorKind::UnexpectedEof { + return Ok(None); + } else { + return Err(cause); + } + } + } + + let byte = res?; + + // Extract the 7 value bits (bits 0-6, excluding the MSB which is the continuation bit) + let value = (byte & 0x7F) as u64; + + // Add the bits to our result at the current shift position + result |= value << shift; + + // If the high bit is not set (0), this is the last byte + if byte & 0x80 == 0 { + break; + } + + // Move to the next 7 bits + shift += 7; + } + + Ok(Some(result)) + } +} + +impl RecvStreamExt for R {} + +pub(crate) trait SendStreamExt: SendStream { + async fn write_length_prefixed(&mut self, value: T) -> io::Result { + let size = postcard::experimental::serialized_size(&value) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + let mut buf = Vec::with_capacity(size + 9); + irpc::util::WriteVarintExt::write_length_prefixed(&mut buf, value)?; + let n = buf.len(); + self.send_bytes(buf.into()).await?; + Ok(n) + } +} + +impl SendStreamExt for W {} From 1390954224f0de925aee53389983baa5cafffb49 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Thu, 11 Sep 2025 14:30:31 +0300 Subject: [PATCH 31/35] clippy --- examples/compression.rs | 16 +++++----------- src/api/remote.rs | 10 +++------- src/provider.rs | 9 ++------- src/util.rs | 2 +- 4 files changed, 11 insertions(+), 26 deletions(-) diff --git a/examples/compression.rs b/examples/compression.rs index af7dd77ef..6966a04b3 100644 --- a/examples/compression.rs +++ b/examples/compression.rs @@ -1,13 +1,7 @@ -/// Example how to limit blob requests by hash and node id, and to add -/// throttling or limiting the maximum number of connections. +/// Example how to use compression with iroh-blobs /// -/// Limiting is done via a fn that returns an EventSender and internally -/// makes liberal use of spawn to spawn background tasks. -/// -/// This is fine, since the tasks will terminate as soon as the [BlobsProtocol] -/// instance holding the [EventSender] will be dropped. But for production -/// grade code you might nevertheless put the tasks into a [tokio::task::JoinSet] or -/// [n0_future::FuturesUnordered]. +/// We create a derived protocol that compresses both requests and responses using lz4 +/// or any other compression algorithm supported by async-compression. 
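Concretely, the example revolves around a small trait of this shape (a sketch reconstructed from the lz4 module further down in this file; the exact supertrait bounds are assumed, not verbatim from the example):

trait Compression: Clone + Send + Sync + 'static {
    // Each codec advertises its own ALPN, derived from the blobs ALPN.
    const ALPN: &[u8];
    // Wrap the raw iroh streams into the abstract stream traits from iroh_blobs::util.
    fn recv_stream(
        &self,
        stream: iroh::endpoint::RecvStream,
    ) -> impl iroh_blobs::util::RecvStream + Sync + 'static;
    fn send_stream(
        &self,
        stream: iroh::endpoint::SendStream,
    ) -> impl iroh_blobs::util::SendStream + Sync + 'static;
}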
 mod common;
 
 use std::{fmt::Debug, path::PathBuf};
@@ -211,14 +205,14 @@ async fn main() -> Result<()> {
         Args::Get { ticket, target } => {
             let store = MemStore::new();
             let conn = endpoint
-                .connect(ticket.node_addr().clone(), &lz4::Compression::ALPN)
+                .connect(ticket.node_addr().clone(), lz4::Compression::ALPN)
                 .await?;
             let connection_id = conn.stable_id() as u64;
             let (send, recv) = conn.open_bi().await?;
             let send = compression.send_stream(send);
             let recv = compression.recv_stream(recv);
             let sp = StreamPair::new(connection_id, recv, send);
-            let stats = store.remote().fetch(sp, ticket.hash_and_format()).await?;
+            let _stats = store.remote().fetch(sp, ticket.hash_and_format()).await?;
             if let Some(target) = target {
                 let size = store.export(ticket.hash(), &target).await?;
                 println!("Wrote {} bytes to {}", size, target.display());
diff --git a/src/api/remote.rs b/src/api/remote.rs
index c95f73353..3d443003e 100644
--- a/src/api/remote.rs
+++ b/src/api/remote.rs
@@ -501,7 +501,7 @@ impl Remote {
 
     pub fn fetch(
         &self,
-        sp: impl GetStreamPair + Send + 'static,
+        sp: impl GetStreamPair + 'static,
         content: impl Into<HashAndFormat>,
     ) -> GetProgress {
         let content = content.into();
@@ -851,12 +851,8 @@ pub trait GetStreamPair: Send + 'static {
     ) -> impl Future<Output = io::Result<StreamPair<impl RecvStream, impl SendStream>>> + Send + 'static;
 }
 
-impl<R: RecvStream, W: SendStream> GetStreamPair
-    for StreamPair<R, W>
-{
-    async fn open_stream_pair(
-        self,
-    ) -> io::Result<StreamPair<R, W>> {
+impl<R: RecvStream, W: SendStream> GetStreamPair for StreamPair<R, W> {
+    async fn open_stream_pair(self) -> io::Result<StreamPair<R, W>> {
         Ok(self)
     }
 }
diff --git a/src/provider.rs b/src/provider.rs
index 7da6aaf91..3a809c7a8 100644
--- a/src/provider.rs
+++ b/src/provider.rs
@@ -34,7 +34,7 @@ use crate::{
         ClientConnected, ClientResult, ConnectionClosed, HasErrorCode, ProgressError,
         RequestTracker,
     },
-    util::{RecvStream, SendStream, SendStreamExt, RecvStreamExt},
+    util::{RecvStream, RecvStreamExt, SendStream, SendStreamExt},
     Hash,
 };
 pub mod events;
@@ -332,12 +332,7 @@ pub trait ErrorHandler {
     fn reset(writer: &mut Self::W, code: VarInt) -> impl Future<Output = ()>;
 }
 
-async fn handle_read_request_result<
-    R: RecvStream,
-    W: SendStream,
-    T,
-    E: HasErrorCode,
->(
+async fn handle_read_request_result<R: RecvStream, W: SendStream, T, E: HasErrorCode>(
     pair: &mut StreamPair<R, W>,
     r: Result<T, E>,
 ) -> Result<T, E> {
diff --git a/src/util.rs b/src/util.rs
index b90c9d8a5..a136a5f15 100644
--- a/src/util.rs
+++ b/src/util.rs
@@ -7,7 +7,7 @@ pub use stream::{
     AsyncReadRecvStream, AsyncWriteSendStream, RecvStream, RecvStreamAsyncStreamReader,
     RecvStreamSpecific, SendStream, SendStreamSpecific,
 };
-pub(crate) use stream::{SendStreamExt, RecvStreamExt};
+pub(crate) use stream::{RecvStreamExt, SendStreamExt};
 
 pub(crate) mod serde {
     // Module that handles io::Error serialization/deserialization
From 3dc6d97c08c29461942ccd709176255946b01831 Mon Sep 17 00:00:00 2001
From: Ruediger Klaehn
Date: Wed, 17 Sep 2025 09:45:40 +0300
Subject: [PATCH 32/35] Remove async-compression dep on compile

We only need async-compression as a dev dep for the compression example!
--- Cargo.lock | 48 ------------------------------------------------ Cargo.toml | 3 +-- 2 files changed, 1 insertion(+), 50 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8f800e160..d03c698f7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -387,8 +387,6 @@ version = "1.2.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "deec109607ca693028562ed836a5f1c4b8bd77755c4e132fc5ce11b0b6211ae7" dependencies = [ - "jobserver", - "libc", "shlex", ] @@ -531,8 +529,6 @@ checksum = "485abf41ac0c8047c07c87c72c8fb3eb5197f6e9d7ded615dfd1a00ae00a0f64" dependencies = [ "compression-core", "lz4", - "zstd", - "zstd-safe", ] [[package]] @@ -2049,16 +2045,6 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" -[[package]] -name = "jobserver" -version = "0.1.33" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38f262f097c174adebe41eb73d66ae9c06b2844fb0da69969647bbddd9b0538a" -dependencies = [ - "getrandom 0.3.3", - "libc", -] - [[package]] name = "js-sys" version = "0.3.77" @@ -2723,12 +2709,6 @@ dependencies = [ "spki", ] -[[package]] -name = "pkg-config" -version = "0.3.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" - [[package]] name = "pnet_base" version = "0.34.0" @@ -5273,31 +5253,3 @@ dependencies = [ "quote", "syn 2.0.104", ] - -[[package]] -name = "zstd" -version = "0.13.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a" -dependencies = [ - "zstd-safe", -] - -[[package]] -name = "zstd-safe" -version = "7.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d" -dependencies = [ - "zstd-sys", -] - -[[package]] -name = "zstd-sys" -version = "2.0.15+zstd.1.5.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb81183ddd97d0c74cedf1d50d85c8d08c1b8b68ee863bdee9e706eedba1a237" -dependencies = [ - "cc", - "pkg-config", -] diff --git a/Cargo.toml b/Cargo.toml index 17b2f50f0..6df5241ad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,7 +42,6 @@ genawaiter = { version = "0.99.1", features = ["futures03"] } iroh-base = "0.91.1" irpc = { version = "0.7.0", features = ["rpc", "quinn_endpoint_setup", "spans", "stream", "derive"], default-features = false } iroh-metrics = { version = "0.35" } -async-compression = { version = "0.4.30", features = ["lz4", "tokio"] } redb = { version = "=2.4", optional = true } reflink-copy = { version = "0.1.24", optional = true } @@ -61,7 +60,7 @@ tracing-test = "0.2.5" walkdir = "2.5.0" atomic_refcell = "0.1.13" iroh = { version = "0.91.1", features = ["discovery-local-network"]} -async-compression = { version = "0.4.30", features = ["zstd", "tokio"] } +async-compression = { version = "0.4.30", features = ["lz4", "tokio"] } concat_const = "0.2.0" [features] From 6ceacfd3f293e125ba4c05749b4bc1b932114733 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Wed, 17 Sep 2025 12:28:44 +0300 Subject: [PATCH 33/35] PR review: added cancellation safety note and id() --- src/util/stream.rs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/util/stream.rs b/src/util/stream.rs index db08b0d54..9834a2075 100644 --- a/src/util/stream.rs +++ b/src/util/stream.rs @@ -13,6 +13,8 @@ use 
tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; /// An abstract `iroh::endpoint::SendStream`. pub trait SendStream: Send { /// Send bytes to the stream. This takes a `Bytes` because iroh can directly use them. + /// + /// This method is not cancellation safe. Even if this does not resolve, some bytes may have been written when previously polled. fn send_bytes(&mut self, bytes: Bytes) -> impl Future> + Send; /// Send that sends a fixed sized buffer. fn send( @@ -25,6 +27,8 @@ pub trait SendStream: Send { fn reset(&mut self, code: VarInt) -> io::Result<()>; /// Wait for the stream to be stopped, returning the error code if it was. fn stopped(&mut self) -> impl Future>> + Send; + /// Get the stream id. + fn id(&self) -> u64; } /// An abstract `iroh::endpoint::RecvStream`. @@ -65,6 +69,10 @@ impl SendStream for iroh::endpoint::SendStream { async fn stopped(&mut self) -> io::Result> { Ok(self.stopped().await?) } + + fn id(&self) -> u64 { + self.id().index() + } } impl RecvStream for iroh::endpoint::RecvStream { @@ -153,6 +161,10 @@ impl SendStream for &mut W { async fn stopped(&mut self) -> io::Result> { self.deref_mut().stopped().await } + + fn id(&self) -> u64 { + self.deref().id() + } } #[derive(Debug)] @@ -289,6 +301,10 @@ impl SendStream for AsyncWriteSendStream { let res = self.0.stopped().await?; Ok(res) } + + fn id(&self) -> u64 { + 0 + } } #[derive(Debug)] From 845e01eac9e0233912f2b73eefd73ac112ce6c00 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Wed, 17 Sep 2025 12:41:17 +0300 Subject: [PATCH 34/35] PR review: rename the weirdly named ...Specific traits hopefully this makes it more clear what they are for! --- examples/compression.rs | 11 ++++++++--- src/util.rs | 4 ++-- src/util/stream.rs | 43 +++++++++++++++++++++++++---------------- 3 files changed, 36 insertions(+), 22 deletions(-) diff --git a/examples/compression.rs b/examples/compression.rs index 6189344cc..343209cd8 100644 --- a/examples/compression.rs +++ b/examples/compression.rs @@ -60,7 +60,8 @@ mod lz4 { use async_compression::tokio::{bufread::Lz4Decoder, write::Lz4Encoder}; use iroh::endpoint::VarInt; use iroh_blobs::util::{ - AsyncReadRecvStream, AsyncWriteSendStream, RecvStreamSpecific, SendStreamSpecific, + AsyncReadRecvStream, AsyncReadRecvStreamExtra, AsyncWriteSendStream, + AsyncWriteSendStreamExtra, }; use tokio::io::{AsyncRead, AsyncWrite, BufReader}; @@ -72,7 +73,7 @@ mod lz4 { } } - impl SendStreamSpecific for SendStream { + impl AsyncWriteSendStreamExtra for SendStream { fn inner(&mut self) -> &mut (impl AsyncWrite + Unpin + Send) { &mut self.0 } @@ -84,6 +85,10 @@ mod lz4 { async fn stopped(&mut self) -> io::Result> { Ok(self.0.get_mut().stopped().await?) 
} + + fn id(&self) -> u64 { + self.0.get_ref().id().index() + } } struct RecvStream(Lz4Decoder>); @@ -94,7 +99,7 @@ mod lz4 { } } - impl RecvStreamSpecific for RecvStream { + impl AsyncReadRecvStreamExtra for RecvStream { fn inner(&mut self) -> &mut (impl AsyncRead + Unpin + Send) { &mut self.0 } diff --git a/src/util.rs b/src/util.rs index 58ae91856..c0acfcaad 100644 --- a/src/util.rs +++ b/src/util.rs @@ -4,8 +4,8 @@ pub mod connection_pool; mod stream; pub(crate) mod temp_tag; pub use stream::{ - AsyncReadRecvStream, AsyncWriteSendStream, RecvStream, RecvStreamAsyncStreamReader, - RecvStreamSpecific, SendStream, SendStreamSpecific, + AsyncReadRecvStream, AsyncReadRecvStreamExtra, AsyncWriteSendStream, AsyncWriteSendStreamExtra, + RecvStream, RecvStreamAsyncStreamReader, SendStream, }; pub(crate) use stream::{RecvStreamExt, SendStreamExt}; diff --git a/src/util/stream.rs b/src/util/stream.rs index 9834a2075..ebdeef1be 100644 --- a/src/util/stream.rs +++ b/src/util/stream.rs @@ -170,13 +170,22 @@ impl SendStream for &mut W { #[derive(Debug)] pub struct AsyncReadRecvStream(R); +/// This is a helper trait to work with [`AsyncReadRecvStream`]. If you have an +/// `AsyncRead + Unpin + Send`, you can implement these additional methods and wrap the result +/// in an `AsyncReadRecvStream` to get a `RecvStream` that reads from the underlying `AsyncRead`. +pub trait AsyncReadRecvStreamExtra: Send { + fn inner(&mut self) -> &mut (impl AsyncRead + Unpin + Send); + fn stop(&mut self, code: VarInt) -> io::Result<()>; + fn id(&self) -> u64; +} + impl AsyncReadRecvStream { pub fn new(inner: R) -> Self { Self(inner) } } -impl RecvStream for AsyncReadRecvStream { +impl RecvStream for AsyncReadRecvStream { async fn recv_bytes(&mut self, len: usize) -> io::Result { let mut res = vec![0; len]; let mut n = 0; @@ -215,18 +224,6 @@ impl RecvStream for AsyncReadRecvStream { } } -pub trait RecvStreamSpecific: Send { - fn inner(&mut self) -> &mut (impl AsyncRead + Unpin + Send); - fn stop(&mut self, code: VarInt) -> io::Result<()>; - fn id(&self) -> u64; -} - -pub trait SendStreamSpecific: Send { - fn inner(&mut self) -> &mut (impl AsyncWrite + Unpin + Send); - fn reset(&mut self, code: VarInt) -> io::Result<()>; - fn stopped(&mut self) -> impl Future>> + Send; -} - impl RecvStream for Bytes { async fn recv_bytes(&mut self, len: usize) -> io::Result { let n = len.min(self.len()); @@ -267,19 +264,31 @@ impl RecvStream for Bytes { #[derive(Debug, Clone)] pub struct AsyncWriteSendStream(W); -impl AsyncWriteSendStream { +/// This is a helper trait to work with [`AsyncWriteSendStream`]. +/// +/// If you have an `AsyncWrite + Unpin + Send`, you can implement these additional +/// methods and wrap the result in an `AsyncWriteSendStream` to get a `SendStream` +/// that writes to the underlying `AsyncWrite`. 
+pub trait AsyncWriteSendStreamExtra: Send {
+    fn inner(&mut self) -> &mut (impl AsyncWrite + Unpin + Send);
+    fn reset(&mut self, code: VarInt) -> io::Result<()>;
+    fn stopped(&mut self) -> impl Future<Output = io::Result<Option<VarInt>>> + Send;
+    fn id(&self) -> u64;
+}
+
+impl<W: AsyncWriteSendStreamExtra> AsyncWriteSendStream<W> {
     pub fn new(inner: W) -> Self {
         Self(inner)
     }
 }
 
-impl<W: SendStreamSpecific> AsyncWriteSendStream<W> {
+impl<W: AsyncWriteSendStreamExtra> AsyncWriteSendStream<W> {
     pub fn into_inner(self) -> W {
         self.0
     }
 }
 
-impl<W: SendStreamSpecific> SendStream for AsyncWriteSendStream<W> {
+impl<W: AsyncWriteSendStreamExtra> SendStream for AsyncWriteSendStream<W> {
     async fn send_bytes(&mut self, bytes: Bytes) -> io::Result<()> {
         self.0.inner().write_all(&bytes).await
     }
@@ -303,7 +312,7 @@ impl<W: SendStreamSpecific> SendStream for AsyncWriteSendStream<W> {
     }
 
     fn id(&self) -> u64 {
-        0
+        self.0.id()
     }
 }
 
From a342ffa99f739ef5b0a060df7f797409dc222a5d Mon Sep 17 00:00:00 2001
From: Ruediger Klaehn
Date: Mon, 6 Oct 2025 11:06:05 +0300
Subject: [PATCH 35/35] Add documentation to the ...Extra adapter traits.

Also remove const generics and make the read and write fns a bit more
"traditional". Makes usage a bit more inconvenient, but :shrug:
---
 src/api/blobs.rs   |  7 ++++-
 src/get.rs         |  3 ++-
 src/util/stream.rs | 67 +++++++++++++++++++++++++++-------------------
 3 files changed, 47 insertions(+), 30 deletions(-)

diff --git a/src/api/blobs.rs b/src/api/blobs.rs
index 4a9c28b10..6e8bbc3c3 100644
--- a/src/api/blobs.rs
+++ b/src/api/blobs.rs
@@ -437,7 +437,12 @@ impl Blobs {
         ranges: ChunkRanges,
         mut reader: R,
     ) -> RequestResult<R> {
-        let size = u64::from_le_bytes(reader.recv::<8>().await.map_err(super::Error::other)?);
+        let mut size = [0; 8];
+        reader
+            .recv_exact(&mut size)
+            .await
+            .map_err(super::Error::other)?;
+        let size = u64::from_le_bytes(size);
         let Some(size) = NonZeroU64::new(size) else {
             return if hash == Hash::EMPTY {
                 Ok(reader)
diff --git a/src/get.rs b/src/get.rs
index 234fe5382..d13092a85 100644
--- a/src/get.rs
+++ b/src/get.rs
@@ -535,7 +535,8 @@ pub mod fsm {
     impl AtBlobHeader {
         /// Read the size header, returning it and going into the `Content` state.
         pub async fn next(mut self) -> Result<(AtBlobContent, u64), AtBlobHeaderNextError> {
-            let size = self.reader.recv::<8>().await.map_err(|cause| {
+            let mut size = [0; 8];
+            self.reader.recv_exact(&mut size).await.map_err(|cause| {
                 if cause.kind() == io::ErrorKind::UnexpectedEof {
                     at_blob_header_next_error::NotFoundSnafu.build()
                 } else {
diff --git a/src/util/stream.rs b/src/util/stream.rs
index ebdeef1be..2816338b1 100644
--- a/src/util/stream.rs
+++ b/src/util/stream.rs
@@ -17,10 +17,7 @@ pub trait SendStream: Send {
     /// This method is not cancellation safe. Even if this does not resolve, some bytes may have been written when previously polled.
     fn send_bytes(&mut self, bytes: Bytes) -> impl Future<Output = io::Result<()>> + Send;
     /// Send that sends a fixed sized buffer.
-    fn send<const L: usize>(
-        &mut self,
-        buf: &[u8; L],
-    ) -> impl Future<Output = io::Result<()>> + Send;
+    fn send(&mut self, buf: &[u8]) -> impl Future<Output = io::Result<()>> + Send;
     /// Sync the stream. Not needed for iroh, but needed for intermediate buffered streams such as compression.
     fn sync(&mut self) -> impl Future<Output = io::Result<()>> + Send;
     /// Reset the stream with the given error code.
@@ -41,8 +38,8 @@ pub trait RecvStream: Send {
     ///
     /// Note that this is different from `recv_bytes`, which will return fewer bytes if the stream ends.
     fn recv_bytes_exact(&mut self, len: usize) -> impl Future<Output = io::Result<Bytes>> + Send;
-    /// Receive exactly `L` bytes from the stream, directly into a `[u8; L]`.
-    fn recv<const L: usize>(&mut self) -> impl Future<Output = io::Result<[u8; L]>> + Send;
+    /// Receive exactly `target.len()` bytes from the stream.
+    fn recv_exact(&mut self, target: &mut [u8]) -> impl Future<Output = io::Result<()>> + Send;
     /// Stop the stream with the given error code.
     fn stop(&mut self, code: VarInt) -> io::Result<()>;
     /// Get the stream id.
@@ -54,7 +51,7 @@ impl SendStream for iroh::endpoint::SendStream {
         Ok(self.write_chunk(bytes).await?)
     }
 
-    async fn send<const L: usize>(&mut self, buf: &[u8; L]) -> io::Result<()> {
+    async fn send(&mut self, buf: &[u8]) -> io::Result<()> {
         Ok(self.write_all(buf).await?)
     }
 
@@ -100,14 +97,12 @@ impl RecvStream for iroh::endpoint::RecvStream {
         Ok(buf.into())
     }
 
-    async fn recv<const L: usize>(&mut self) -> io::Result<[u8; L]> {
-        let mut buf = [0; L];
-        self.read_exact(&mut buf).await.map_err(|e| match e {
+    async fn recv_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
+        self.read_exact(buf).await.map_err(|e| match e {
             ReadExactError::FinishedEarly(0) => io::Error::new(io::ErrorKind::UnexpectedEof, ""),
             ReadExactError::FinishedEarly(_) => io::Error::new(io::ErrorKind::InvalidData, ""),
             ReadExactError::ReadError(e) => e.into(),
-        })?;
-        Ok(buf)
+        })
     }
 
     fn stop(&mut self, code: VarInt) -> io::Result<()> {
@@ -128,8 +123,8 @@ impl<R: RecvStream> RecvStream for &mut R {
         self.deref_mut().recv_bytes_exact(len).await
     }
 
-    async fn recv<const L: usize>(&mut self) -> io::Result<[u8; L]> {
-        self.deref_mut().recv::<L>().await
+    async fn recv_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
+        self.deref_mut().recv_exact(buf).await
     }
 
     fn stop(&mut self, code: VarInt) -> io::Result<()> {
@@ -146,7 +141,7 @@ impl<W: SendStream> SendStream for &mut W {
         self.deref_mut().send_bytes(bytes).await
     }
 
-    async fn send<const L: usize>(&mut self, buf: &[u8; L]) -> io::Result<()> {
+    async fn send(&mut self, buf: &[u8]) -> io::Result<()> {
         self.deref_mut().send(buf).await
     }
 
@@ -174,8 +169,15 @@ pub struct AsyncReadRecvStream<R>(R);
 /// `AsyncRead + Unpin + Send`, you can implement these additional methods and wrap the result
 /// in an `AsyncReadRecvStream` to get a `RecvStream` that reads from the underlying `AsyncRead`.
 pub trait AsyncReadRecvStreamExtra: Send {
+    /// Get a mutable reference to the inner `AsyncRead`.
+    ///
+    /// Getting a reference is easier than implementing all methods on `AsyncRead` with forwarders to the inner instance.
     fn inner(&mut self) -> &mut (impl AsyncRead + Unpin + Send);
+    /// Stop the stream with the given error code.
     fn stop(&mut self, code: VarInt) -> io::Result<()>;
+    /// A local unique identifier for the stream.
+    ///
+    /// This allows distinguishing between streams, but once the stream is closed, the id may be reused.
     fn id(&self) -> u64;
 }
 
@@ -209,10 +211,9 @@ impl<R: AsyncReadRecvStreamExtra> RecvStream for AsyncReadRecvStream<R> {
         Ok(res.into())
     }
 
-    async fn recv<const L: usize>(&mut self) -> io::Result<[u8; L]> {
-        let mut res = [0; L];
-        self.0.inner().read_exact(&mut res).await?;
-        Ok(res)
+    async fn recv_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
+        self.0.inner().read_exact(buf).await?;
+        Ok(())
     }
 
     fn stop(&mut self, code: VarInt) -> io::Result<()> {
@@ -241,14 +242,13 @@ impl RecvStream for Bytes {
         Ok(res)
     }
 
-    async fn recv<const L: usize>(&mut self) -> io::Result<[u8; L]> {
-        if self.len() < L {
+    async fn recv_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
+        if self.len() < buf.len() {
             return Err(io::ErrorKind::UnexpectedEof.into());
         }
-        let mut res = [0; L];
-        res.copy_from_slice(&self[..L]);
-        *self = self.slice(L..);
-        Ok(res)
+        buf.copy_from_slice(&self[..buf.len()]);
+        *self = self.slice(buf.len()..);
+        Ok(())
     }
 
     fn stop(&mut self, _code: VarInt) -> io::Result<()> {
@@ -270,9 +270,17 @@ pub struct AsyncWriteSendStream<W>(W);
 /// methods and wrap the result in an `AsyncWriteSendStream` to get a `SendStream`
 /// that writes to the underlying `AsyncWrite`.
 pub trait AsyncWriteSendStreamExtra: Send {
+    /// Get a mutable reference to the inner `AsyncWrite`.
+    ///
+    /// Getting a reference is easier than implementing all methods on `AsyncWrite` with forwarders to the inner instance.
     fn inner(&mut self) -> &mut (impl AsyncWrite + Unpin + Send);
+    /// Reset the stream with the given error code.
     fn reset(&mut self, code: VarInt) -> io::Result<()>;
+    /// Wait for the stream to be stopped, returning the optional error code if it was.
     fn stopped(&mut self) -> impl Future<Output = io::Result<Option<VarInt>>> + Send;
+    /// A local unique identifier for the stream.
+    ///
+    /// This allows distinguishing between streams, but once the stream is closed, the id may be reused.
     fn id(&self) -> u64;
 }
 
@@ -293,7 +301,7 @@ impl<W: AsyncWriteSendStreamExtra> SendStream for AsyncWriteSendStream<W> {
         self.0.inner().write_all(&bytes).await
     }
 
-    async fn send<const L: usize>(&mut self, buf: &[u8; L]) -> io::Result<()> {
+    async fn send(&mut self, buf: &[u8]) -> io::Result<()> {
         self.0.inner().write_all(buf).await
     }
 
@@ -335,7 +343,9 @@ impl<R: RecvStream> AsyncStreamReader for RecvStreamAsyncStreamReader<R> {
     }
 
     async fn read<const L: usize>(&mut self) -> io::Result<[u8; L]> {
-        self.0.recv::<L>().await
+        let mut buf = [0; L];
+        self.0.recv_exact(&mut buf).await?;
+        Ok(buf)
     }
 }
 
@@ -352,7 +362,8 @@ pub(crate) trait RecvStreamExt: RecvStream {
     }
 
     async fn read_u8(&mut self) -> io::Result<u8> {
-        let buf = self.recv::<1>().await?;
+        let mut buf = [0; 1];
+        self.recv_exact(&mut buf).await?;
         Ok(buf[0])
     }
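With the const generics gone, callers now allocate the fixed buffer themselves and pass a slice, as the `import_bao_reader` and `AtBlobHeader::next` hunks above show. A minimal sketch of the new calling convention (the helper `read_le_u64` is illustrative, not part of the patch):

use std::io;

use iroh_blobs::util::RecvStream;

async fn read_le_u64(reader: &mut impl RecvStream) -> io::Result<u64> {
    // recv_exact fills the whole buffer or fails with UnexpectedEof,
    // replacing the old `reader.recv::<8>().await?` call.
    let mut buf = [0u8; 8];
    reader.recv_exact(&mut buf).await?;
    Ok(u64::from_le_bytes(buf))
}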