diff --git a/Cargo.lock b/Cargo.lock index 4068354f7..625f30b7b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -156,6 +156,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "async-compression" +version = "0.4.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "977eb15ea9efd848bb8a4a1a2500347ed7f0bf794edf0dc3ddcf439f43d36b23" +dependencies = [ + "compression-codecs", + "compression-core", + "futures-core", + "pin-project-lite", + "tokio", +] + [[package]] name = "async-trait" version = "0.1.88" @@ -374,6 +387,8 @@ version = "1.2.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "deec109607ca693028562ed836a5f1c4b8bd77755c4e132fc5ce11b0b6211ae7" dependencies = [ + "jobserver", + "libc", "shlex", ] @@ -508,6 +523,24 @@ dependencies = [ "memchr", ] +[[package]] +name = "compression-codecs" +version = "0.4.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "485abf41ac0c8047c07c87c72c8fb3eb5197f6e9d7ded615dfd1a00ae00a0f64" +dependencies = [ + "compression-core", + "lz4", + "zstd", + "zstd-safe", +] + +[[package]] +name = "compression-core" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e47641d3deaf41fb1538ac1f54735925e275eaf3bf4d55c81b137fba797e5cbb" + [[package]] name = "const-oid" version = "0.9.6" @@ -1741,6 +1774,7 @@ version = "0.93.0" dependencies = [ "anyhow", "arrayvec", + "async-compression", "atomic_refcell", "bao-tree", "bytes", @@ -2008,6 +2042,16 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" +[[package]] +name = "jobserver" +version = "0.1.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38f262f097c174adebe41eb73d66ae9c06b2844fb0da69969647bbddd9b0538a" +dependencies = [ + "getrandom 0.3.3", + "libc", +] + [[package]] name = "js-sys" version = "0.3.77" @@ -2092,13 +2136,32 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" +[[package]] +name = "lz4" +version = "1.28.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a20b523e860d03443e98350ceaac5e71c6ba89aea7d960769ec3ce37f4de5af4" +dependencies = [ + "lz4-sys", +] + +[[package]] +name = "lz4-sys" +version = "1.11.1+lz4-1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bd8c0d6c6ed0cd30b3652886bb8711dc4bb01d637a68105a3d5158039b418e6" +dependencies = [ + "cc", + "libc", +] + [[package]] name = "matchers" -version = "0.1.0" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" dependencies = [ - "regex-automata 0.1.10", + "regex-automata", ] [[package]] @@ -2385,12 +2448,11 @@ dependencies = [ [[package]] name = "nu-ansi-term" -version = "0.46.0" +version = "0.50.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +checksum = "d4a28e057d01f97e61255210fcff094d74ed0466038633e95017f5beb68e4399" dependencies = [ - "overload", - "winapi", + "windows-sys 0.52.0", ] [[package]] @@ -2467,12 +2529,6 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" -[[package]] -name = "overload" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" - [[package]] name = "parking" version = "2.2.1" @@ -2660,6 +2716,12 @@ dependencies = [ "spki", ] +[[package]] +name = "pkg-config" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" + [[package]] name = "pnet_base" version = "0.34.0" @@ -2907,7 +2969,7 @@ dependencies = [ "rand 0.9.2", "rand_chacha 0.9.0", "rand_xorshift", - "regex-syntax 0.8.5", + "regex-syntax", "rusty-fork", "tempfile", "unarray", @@ -3151,17 +3213,8 @@ checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.9", - "regex-syntax 0.8.5", -] - -[[package]] -name = "regex-automata" -version = "0.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" -dependencies = [ - "regex-syntax 0.6.29", + "regex-automata", + "regex-syntax", ] [[package]] @@ -3172,7 +3225,7 @@ checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", - "regex-syntax 0.8.5", + "regex-syntax", ] [[package]] @@ -3181,12 +3234,6 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a" -[[package]] -name = "regex-syntax" -version = "0.6.29" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" - [[package]] name = "regex-syntax" version = "0.8.5" @@ -4245,14 +4292,14 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.19" +version = "0.3.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" +checksum = "2054a14f5307d601f88daf0553e1cbf472acc4f2c51afab632431cdcd72124d5" dependencies = [ "matchers", "nu-ansi-term", "once_cell", - "regex", + "regex-automata", "sharded-slab", "smallvec", "thread_local", @@ -5219,3 +5266,31 @@ dependencies = [ "quote", "syn 2.0.104", ] + +[[package]] +name = "zstd" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "7.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.15+zstd.1.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb81183ddd97d0c74cedf1d50d85c8d08c1b8b68ee863bdee9e706eedba1a237" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/Cargo.toml b/Cargo.toml index bcd5f42d0..a40b735bd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,6 +44,7 @@ iroh-base = "0.91.1" reflink-copy = "0.1.24" irpc = { version = "0.7.0", features = ["rpc", "quinn_endpoint_setup", "spans", "stream", "derive"], default-features = false } iroh-metrics = { version = "0.35" } +async-compression = { version = "0.4.30", features = ["lz4", "tokio"] } [dev-dependencies] clap = { version = "4.5.31", features = ["derive"] } @@ -60,6 +61,7 @@ tracing-test = "0.2.5" walkdir = "2.5.0" atomic_refcell = "0.1.13" iroh = { version = "0.91.1", features = ["discovery-local-network"]} +async-compression = { version = "0.4.30", features = ["zstd", "tokio"] } [features] hide-proto-docs = [] @@ -68,4 +70,4 @@ default = ["hide-proto-docs"] [patch.crates-io] iroh = { git = "https://github.com/n0-computer/iroh", branch = "main" } -iroh-base = { git = "https://github.com/n0-computer/iroh", branch = "main" } \ No newline at end of file +iroh-base = { git = "https://github.com/n0-computer/iroh", branch = "main" } diff --git a/examples/compression.rs b/examples/compression.rs new file mode 100644 index 000000000..322277ea9 --- /dev/null +++ b/examples/compression.rs @@ -0,0 +1,194 @@ +/// Example how to limit blob requests by hash and node id, and to add +/// throttling or limiting the maximum number of connections. +/// +/// Limiting is done via a fn that returns an EventSender and internally +/// makes liberal use of spawn to spawn background tasks. +/// +/// This is fine, since the tasks will terminate as soon as the [BlobsProtocol] +/// instance holding the [EventSender] will be dropped. But for production +/// grade code you might nevertheless put the tasks into a [tokio::task::JoinSet] or +/// [n0_future::FuturesUnordered]. +mod common; +use std::{path::PathBuf, time::Instant}; + +use anyhow::Result; +use async_compression::tokio::{bufread::Lz4Decoder, write::Lz4Encoder}; +use bao_tree::blake3; +use clap::Parser; +use common::setup_logging; +use iroh::protocol::ProtocolHandler; +use iroh_blobs::{ + api::Store, + get::fsm::{AtConnected, ConnectedNext, EndBlobNext}, + protocol::{ChunkRangesSeq, GetRequest, Request}, + provider::{ + events::{ClientConnected, EventSender, HasErrorCode}, + handle_get, ErrorHandler, StreamPair, + }, + store::mem::MemStore, + ticket::BlobTicket, +}; +use iroh_io::{TokioStreamReader, TokioStreamWriter}; +use tokio::io::BufReader; +use tracing::debug; + +use crate::common::get_or_generate_secret_key; + +#[derive(Debug, Parser)] +#[command(version, about)] +pub enum Args { + /// Limit requests by node id + Provide { + /// Path for files to add. + path: PathBuf, + }, + /// Get a blob. Just for completeness sake. + Get { + /// Ticket for the blob to download + ticket: BlobTicket, + /// Path to save the blob to + #[clap(long)] + target: Option, + }, +} + +type CompressedWriter = + TokioStreamWriter>; +type CompressedReader = TokioStreamReader< + async_compression::tokio::bufread::Lz4Decoder>, +>; + +#[derive(Debug, Clone)] +struct CompressedBlobsProtocol { + store: Store, + events: EventSender, +} + +impl CompressedBlobsProtocol { + fn new(store: &Store, events: EventSender) -> Self { + Self { + store: store.clone(), + events, + } + } +} + +struct CompressedErrorHandler; + +impl ErrorHandler for CompressedErrorHandler { + type W = CompressedWriter; + + type R = CompressedReader; + + async fn stop(reader: &mut Self::R, code: quinn::VarInt) { + reader.0.get_mut().get_mut().stop(code).ok(); + } + + async fn reset(writer: &mut Self::W, code: quinn::VarInt) { + writer.0.get_mut().reset(code).ok(); + } +} + +impl ProtocolHandler for CompressedBlobsProtocol { + async fn accept( + &self, + connection: iroh::endpoint::Connection, + ) -> std::result::Result<(), iroh::protocol::AcceptError> { + let connection_id = connection.stable_id() as u64; + let node_id = connection.remote_node_id()?; + if let Err(cause) = self + .events + .client_connected(|| ClientConnected { + connection_id, + node_id, + }) + .await + { + connection.close(cause.code(), cause.reason()); + debug!("closing connection: {cause}"); + return Ok(()); + } + while let Ok((send, recv)) = connection.accept_bi().await { + println!("Accepted new stream"); + let stream_id = send.id().index(); + let send = TokioStreamWriter(Lz4Encoder::new(send)); + let recv = TokioStreamReader(Lz4Decoder::new(BufReader::new(recv))); + let store = self.store.clone(); + let mut pair = + StreamPair::new(connection_id, stream_id, recv, send, self.events.clone()); + tokio::spawn(async move { + println!("Handling stream"); + let request = pair.read_request().await?; + println!("Received request: {request:?}"); + if let Request::Get(request) = request { + handle_get::(pair, store, request).await?; + } + anyhow::Ok(()) + }); + } + Ok(()) + } +} + +const ALPN: &[u8] = b"iroh-blobs-compressed/0.1.0"; + +#[tokio::main] +async fn main() -> Result<()> { + setup_logging(); + let args = Args::parse(); + let secret = get_or_generate_secret_key()?; + let endpoint = iroh::Endpoint::builder() + .secret_key(secret) + .discovery_n0() + .bind() + .await?; + match args { + Args::Provide { path } => { + let store = MemStore::new(); + let tag = store.add_path(path).await?; + let blobs = CompressedBlobsProtocol::new(&store, EventSender::DEFAULT); + let router = iroh::protocol::Router::builder(endpoint.clone()) + .accept(ALPN, blobs) + .spawn(); + let ticket = BlobTicket::new(endpoint.node_id().into(), tag.hash, tag.format); + println!("Serving blob with hash {}", tag.hash); + println!("Ticket: {ticket}"); + println!("Node is running. Press Ctrl-C to exit."); + tokio::signal::ctrl_c().await?; + println!("Shutting down."); + router.shutdown().await?; + } + Args::Get { ticket, target } => { + let conn = endpoint.connect(ticket.node_addr().clone(), ALPN).await?; + let (send, recv) = conn.open_bi().await?; + let send = TokioStreamWriter(Lz4Encoder::new(send)); + let recv = TokioStreamReader(Lz4Decoder::new(BufReader::new(recv))); + let request = GetRequest { + hash: ticket.hash(), + ranges: ChunkRangesSeq::root(), + }; + let connected = + AtConnected::new(Instant::now(), recv, send, request, Default::default()); + let ConnectedNext::StartRoot(start) = connected.next().await? else { + unreachable!("expected start root"); + }; + let (end, data) = start.next().concatenate_into_vec().await?; + let EndBlobNext::Closing(closing) = end.next() else { + unreachable!("expected closing"); + }; + let stats = closing.next().await?; + if let Some(target) = target { + tokio::fs::write(&target, &data).await?; + println!( + "Wrote {} bytes to {}", + stats.payload_bytes_read, + target.display() + ); + } else { + let hash = blake3::hash(&data); + println!("Hash: {hash}"); + } + } + } + Ok(()) +} diff --git a/examples/limit.rs b/examples/limit.rs new file mode 100644 index 000000000..e72f9be59 --- /dev/null +++ b/examples/limit.rs @@ -0,0 +1,371 @@ +/// Example how to limit blob requests by hash and node id, and to add +/// throttling or limiting the maximum number of connections. +/// +/// Limiting is done via a fn that returns an EventSender and internally +/// makes liberal use of spawn to spawn background tasks. +/// +/// This is fine, since the tasks will terminate as soon as the [BlobsProtocol] +/// instance holding the [EventSender] will be dropped. But for production +/// grade code you might nevertheless put the tasks into a [tokio::task::JoinSet] or +/// [n0_future::FuturesUnordered]. +mod common; +use std::{ + collections::{HashMap, HashSet}, + path::PathBuf, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, +}; + +use anyhow::Result; +use clap::Parser; +use common::setup_logging; +use iroh::{protocol::Router, NodeAddr, NodeId, SecretKey, Watcher}; +use iroh_blobs::{ + provider::events::{ + AbortReason, ConnectMode, EventMask, EventSender, ProviderMessage, RequestMode, + ThrottleMode, + }, + store::mem::MemStore, + ticket::BlobTicket, + BlobFormat, BlobsProtocol, Hash, +}; +use rand::thread_rng; + +use crate::common::get_or_generate_secret_key; + +#[derive(Debug, Parser)] +#[command(version, about)] +pub enum Args { + /// Limit requests by node id + ByNodeId { + /// Path for files to add. + paths: Vec, + #[clap(long("allow"))] + /// Nodes that are allowed to download content. + allowed_nodes: Vec, + /// Number of secrets to generate for allowed node ids. + #[clap(long, default_value_t = 1)] + secrets: usize, + }, + /// Limit requests by hash, only first hash is allowed + ByHash { + /// Path for files to add. + paths: Vec, + }, + /// Throttle requests + Throttle { + /// Path for files to add. + paths: Vec, + /// Delay in milliseconds after sending a chunk group of 16 KiB. + #[clap(long, default_value = "100")] + delay_ms: u64, + }, + /// Limit maximum number of connections. + MaxConnections { + /// Path for files to add. + paths: Vec, + /// Maximum number of concurrent get requests. + #[clap(long, default_value = "1")] + max_connections: usize, + }, + /// Get a blob. Just for completeness sake. + Get { + /// Ticket for the blob to download + ticket: BlobTicket, + }, +} + +fn limit_by_node_id(allowed_nodes: HashSet) -> EventSender { + let mask = EventMask { + // We want a request for each incoming connection so we can accept + // or reject them. We don't need any other events. + connected: ConnectMode::Request, + ..EventMask::DEFAULT + }; + let (tx, mut rx) = EventSender::channel(32, mask); + n0_future::task::spawn(async move { + while let Some(msg) = rx.recv().await { + if let ProviderMessage::ClientConnected(msg) = msg { + let node_id = msg.node_id; + let res = if allowed_nodes.contains(&node_id) { + println!("Client connected: {node_id}"); + Ok(()) + } else { + println!("Client rejected: {node_id}"); + Err(AbortReason::Permission) + }; + msg.tx.send(res).await.ok(); + } + } + }); + tx +} + +fn limit_by_hash(allowed_hashes: HashSet) -> EventSender { + let mask = EventMask { + // We want to get a request for each get request that we can answer + // with OK or not OK depending on the hash. We do not want detailed + // events once it has been decided to handle a request. + get: RequestMode::Request, + ..EventMask::DEFAULT + }; + let (tx, mut rx) = EventSender::channel(32, mask); + n0_future::task::spawn(async move { + while let Some(msg) = rx.recv().await { + if let ProviderMessage::GetRequestReceived(msg) = msg { + let res = if !msg.request.ranges.is_blob() { + println!("HashSeq request not allowed"); + Err(AbortReason::Permission) + } else if !allowed_hashes.contains(&msg.request.hash) { + println!("Request for hash {} not allowed", msg.request.hash); + Err(AbortReason::Permission) + } else { + println!("Request for hash {} allowed", msg.request.hash); + Ok(()) + }; + msg.tx.send(res).await.ok(); + } + } + }); + tx +} + +fn throttle(delay_ms: u64) -> EventSender { + let mask = EventMask { + // We want to get requests for each sent user data blob, so we can add a delay. + // Other than that, we don't need any events. + throttle: ThrottleMode::Throttle, + ..EventMask::DEFAULT + }; + let (tx, mut rx) = EventSender::channel(32, mask); + n0_future::task::spawn(async move { + while let Some(msg) = rx.recv().await { + if let ProviderMessage::Throttle(msg) = msg { + n0_future::task::spawn(async move { + println!( + "Throttling {} {}, {}ms", + msg.connection_id, msg.request_id, delay_ms + ); + // we could compute the delay from the size of the data to have a fixed rate. + // but the size is almost always 16 KiB (16 chunks). + tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; + msg.tx.send(Ok(())).await.ok(); + }); + } + } + }); + tx +} + +fn limit_max_connections(max_connections: usize) -> EventSender { + #[derive(Default, Debug, Clone)] + struct ConnectionCounter(Arc<(AtomicUsize, usize)>); + + impl ConnectionCounter { + fn new(max: usize) -> Self { + Self(Arc::new((Default::default(), max))) + } + + fn inc(&self) -> Result { + let (c, max) = &*self.0; + c.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |n| { + if n >= *max { + None + } else { + Some(n + 1) + } + }) + } + + fn dec(&self) { + let (c, _) = &*self.0; + c.fetch_sub(1, Ordering::SeqCst); + } + } + + let mask = EventMask { + // For each get request, we want to get a request so we can decide + // based on the current connection count if we want to accept or reject. + // We also want detailed logging of events for the get request, so we can + // detect when the request is finished one way or another. + get: RequestMode::RequestLog, + ..EventMask::DEFAULT + }; + let (tx, mut rx) = EventSender::channel(32, mask); + n0_future::task::spawn(async move { + let requests = ConnectionCounter::new(max_connections); + while let Some(msg) = rx.recv().await { + if let ProviderMessage::GetRequestReceived(mut msg) = msg { + let connection_id = msg.connection_id; + let request_id = msg.request_id; + let res = requests.inc(); + match res { + Ok(n) => { + println!("Accepting request {n}, id ({connection_id},{request_id})"); + msg.tx.send(Ok(())).await.ok(); + } + Err(_) => { + println!( + "Connection limit of {max_connections} exceeded, rejecting request" + ); + msg.tx.send(Err(AbortReason::RateLimited)).await.ok(); + continue; + } + } + let requests = requests.clone(); + n0_future::task::spawn(async move { + // just drain the per request events + // + // Note that we have requested updates for the request, now we also need to process them + // otherwise the request will be aborted! + while let Ok(Some(_)) = msg.rx.recv().await {} + println!("Stopping request, id ({connection_id},{request_id})"); + requests.dec(); + }); + } + } + }); + tx +} + +#[tokio::main] +async fn main() -> Result<()> { + setup_logging(); + let args = Args::parse(); + let secret = get_or_generate_secret_key()?; + let endpoint = iroh::Endpoint::builder() + .secret_key(secret) + .discovery_n0() + .bind() + .await?; + match args { + Args::Get { ticket } => { + let connection = endpoint + .connect(ticket.node_addr().clone(), iroh_blobs::ALPN) + .await?; + let (data, stats) = iroh_blobs::get::request::get_blob(connection, ticket.hash()) + .bytes_and_stats() + .await?; + println!("Downloaded {} bytes", data.len()); + println!("Stats: {stats:?}"); + } + Args::ByNodeId { + paths, + allowed_nodes, + secrets, + } => { + let mut allowed_nodes = allowed_nodes.into_iter().collect::>(); + if secrets > 0 { + println!("Generating {secrets} new secret keys for allowed nodes:"); + let mut rand = thread_rng(); + for _ in 0..secrets { + let secret = SecretKey::generate(&mut rand); + let public = secret.public(); + allowed_nodes.insert(public); + println!("IROH_SECRET={}", hex::encode(secret.to_bytes())); + } + } + + let store = MemStore::new(); + let hashes = add_paths(&store, paths).await?; + let events = limit_by_node_id(allowed_nodes.clone()); + let (router, addr) = setup(store, events).await?; + + for (path, hash) in hashes { + let ticket = BlobTicket::new(addr.clone(), hash, BlobFormat::Raw); + println!("{}: {ticket}", path.display()); + } + println!(); + println!("Node id: {}\n", router.endpoint().node_id()); + for id in &allowed_nodes { + println!("Allowed node: {id}"); + } + + tokio::signal::ctrl_c().await?; + router.shutdown().await?; + } + Args::ByHash { paths } => { + let store = MemStore::new(); + + let mut hashes = HashMap::new(); + let mut allowed_hashes = HashSet::new(); + for (i, path) in paths.into_iter().enumerate() { + let tag = store.add_path(&path).await?; + hashes.insert(path, tag.hash); + if i == 0 { + allowed_hashes.insert(tag.hash); + } + } + + let events = limit_by_hash(allowed_hashes.clone()); + let (router, addr) = setup(store, events).await?; + + for (path, hash) in hashes.iter() { + let ticket = BlobTicket::new(addr.clone(), *hash, BlobFormat::Raw); + let permitted = if allowed_hashes.contains(hash) { + "allowed" + } else { + "forbidden" + }; + println!("{}: {ticket} ({permitted})", path.display()); + } + tokio::signal::ctrl_c().await?; + router.shutdown().await?; + } + Args::Throttle { paths, delay_ms } => { + let store = MemStore::new(); + let hashes = add_paths(&store, paths).await?; + let events = throttle(delay_ms); + let (router, addr) = setup(store, events).await?; + for (path, hash) in hashes { + let ticket = BlobTicket::new(addr.clone(), hash, BlobFormat::Raw); + println!("{}: {ticket}", path.display()); + } + tokio::signal::ctrl_c().await?; + router.shutdown().await?; + } + Args::MaxConnections { + paths, + max_connections, + } => { + let store = MemStore::new(); + let hashes = add_paths(&store, paths).await?; + let events = limit_max_connections(max_connections); + let (router, addr) = setup(store, events).await?; + for (path, hash) in hashes { + let ticket = BlobTicket::new(addr.clone(), hash, BlobFormat::Raw); + println!("{}: {ticket}", path.display()); + } + tokio::signal::ctrl_c().await?; + router.shutdown().await?; + } + } + Ok(()) +} + +async fn add_paths(store: &MemStore, paths: Vec) -> Result> { + let mut hashes = HashMap::new(); + for path in paths { + let tag = store.add_path(&path).await?; + hashes.insert(path, tag.hash); + } + Ok(hashes) +} + +async fn setup(store: MemStore, events: EventSender) -> Result<(Router, NodeAddr)> { + let secret = get_or_generate_secret_key()?; + let endpoint = iroh::Endpoint::builder() + .discovery_n0() + .secret_key(secret) + .bind() + .await?; + let _ = endpoint.home_relay().initialized().await; + let addr = endpoint.node_addr().initialized().await; + let blobs = BlobsProtocol::new(&store, endpoint.clone(), Some(events)); + let router = Router::builder(endpoint) + .accept(iroh_blobs::ALPN, blobs) + .spawn(); + Ok((router, addr)) +} diff --git a/examples/random_store.rs b/examples/random_store.rs index ffdd9b826..d3f9a0fc4 100644 --- a/examples/random_store.rs +++ b/examples/random_store.rs @@ -6,14 +6,15 @@ use iroh::{SecretKey, Watcher}; use iroh_base::ticket::NodeTicket; use iroh_blobs::{ api::downloader::Shuffled, - provider::Event, + provider::events::{AbortReason, EventMask, EventSender, ProviderMessage}, store::fs::FsStore, test::{add_hash_sequences, create_random_blobs}, HashAndFormat, }; +use irpc::RpcMessage; use n0_future::StreamExt; use rand::{rngs::StdRng, Rng, SeedableRng}; -use tokio::{signal::ctrl_c, sync::mpsc}; +use tokio::signal::ctrl_c; use tracing::info; #[derive(Parser, Debug)] @@ -100,77 +101,77 @@ pub fn get_or_generate_secret_key() -> Result { } } -pub fn dump_provider_events( - allow_push: bool, -) -> ( - tokio::task::JoinHandle<()>, - mpsc::Sender, -) { - let (tx, mut rx) = mpsc::channel(100); +pub fn dump_provider_events(allow_push: bool) -> (tokio::task::JoinHandle<()>, EventSender) { + let (tx, mut rx) = EventSender::channel(100, EventMask::ALL_READONLY); + fn dump_updates(mut rx: irpc::channel::mpsc::Receiver) { + tokio::spawn(async move { + while let Ok(Some(update)) = rx.recv().await { + println!("{update:?}"); + } + }); + } let dump_task = tokio::spawn(async move { while let Some(event) = rx.recv().await { match event { - Event::ClientConnected { - node_id, - connection_id, - permitted, - } => { - permitted.send(true).await.ok(); - println!("Client connected: {node_id} {connection_id}"); + ProviderMessage::ClientConnected(msg) => { + println!("{:?}", msg.inner); + msg.tx.send(Ok(())).await.ok(); + } + ProviderMessage::ClientConnectedNotify(msg) => { + println!("{:?}", msg.inner); + } + ProviderMessage::ConnectionClosed(msg) => { + println!("{:?}", msg.inner); } - Event::GetRequestReceived { - connection_id, - request_id, - hash, - ranges, - } => { - println!( - "Get request received: {connection_id} {request_id} {hash} {ranges:?}" - ); + ProviderMessage::GetRequestReceived(msg) => { + println!("{:?}", msg.inner); + msg.tx.send(Ok(())).await.ok(); + dump_updates(msg.rx); } - Event::TransferCompleted { - connection_id, - request_id, - stats, - } => { - println!("Transfer completed: {connection_id} {request_id} {stats:?}"); + ProviderMessage::GetRequestReceivedNotify(msg) => { + println!("{:?}", msg.inner); + dump_updates(msg.rx); } - Event::TransferAborted { - connection_id, - request_id, - stats, - } => { - println!("Transfer aborted: {connection_id} {request_id} {stats:?}"); + ProviderMessage::GetManyRequestReceived(msg) => { + println!("{:?}", msg.inner); + msg.tx.send(Ok(())).await.ok(); + dump_updates(msg.rx); } - Event::TransferProgress { - connection_id, - request_id, - index, - end_offset, - } => { - info!("Transfer progress: {connection_id} {request_id} {index} {end_offset}"); + ProviderMessage::GetManyRequestReceivedNotify(msg) => { + println!("{:?}", msg.inner); + dump_updates(msg.rx); } - Event::PushRequestReceived { - connection_id, - request_id, - hash, - ranges, - permitted, - } => { - if allow_push { - permitted.send(true).await.ok(); - println!( - "Push request received: {connection_id} {request_id} {hash} {ranges:?}" - ); + ProviderMessage::PushRequestReceived(msg) => { + println!("{:?}", msg.inner); + let res = if allow_push { + Ok(()) } else { - permitted.send(false).await.ok(); - println!( - "Push request denied: {connection_id} {request_id} {hash} {ranges:?}" - ); - } + Err(AbortReason::Permission) + }; + msg.tx.send(res).await.ok(); + dump_updates(msg.rx); + } + ProviderMessage::PushRequestReceivedNotify(msg) => { + println!("{:?}", msg.inner); + dump_updates(msg.rx); + } + ProviderMessage::ObserveRequestReceived(msg) => { + println!("{:?}", msg.inner); + let res = if allow_push { + Ok(()) + } else { + Err(AbortReason::Permission) + }; + msg.tx.send(res).await.ok(); + dump_updates(msg.rx); + } + ProviderMessage::ObserveRequestReceivedNotify(msg) => { + println!("{:?}", msg.inner); + dump_updates(msg.rx); } - _ => { - info!("Received event: {:?}", event); + ProviderMessage::Throttle(msg) => { + println!("{:?}", msg.inner); + msg.tx.send(Ok(())).await.ok(); } } } diff --git a/proptest-regressions/store/fs/util/entity_manager.txt b/proptest-regressions/store/fs/util/entity_manager.txt new file mode 100644 index 000000000..94b6aa63c --- /dev/null +++ b/proptest-regressions/store/fs/util/entity_manager.txt @@ -0,0 +1,7 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc 0f2ebc49ab2f84e112f08407bb94654fbcb1f19050a4a8a6196383557696438a # shrinks to input = _TestCountersManagerProptestFsArgs { entries: [(15313427648878534792, 264348813928009031854006459208395772047), (1642534478798447378, 15989109311941500072752977306696275871), (8755041673862065815, 172763711808688570294350362332402629716), (4993597758667891804, 114145440157220458287429360639759690928), (15031383154962489250, 63217081714858286463391060323168548783), (17668469631267503333, 11878544422669770587175118199598836678), (10507570291819955314, 126584081645379643144412921692654648228), (3979008599365278329, 283717221942996985486273080647433218905), (8316838360288996639, 334043288511621783152802090833905919408), (15673798930962474157, 77551315511802713260542200115027244708), (12058791254144360414, 56638044274259821850511200885092637649), (8191628769638031337, 314181956273420400069887649110740549194), (6290369460137232066, 255779791286732775990301011955519176773), (11919824746661852269, 319400891587146831511371932480749645441), (12491631698789073154, 271279849791970841069522263758329847554), (53891048909263304, 12061234604041487609497959407391945555), (9486366498650667097, 311383186592430597410801882015456718030), (15696332331789302593, 306911490707714340526403119780178604150), (8699088947997536151, 312272624973367009520183311568498652066), (1144772544750976199, 200591877747619565555594857038887015), (5907208586200645081, 299942008952473970881666769409865744975), (3384528743842518913, 26230956866762934113564101494944411446), (13877357832690956494, 229457597607752760006918374695475345151), (2965687966026226090, 306489188264741716662410004273408761623), (13624286905717143613, 232801392956394366686194314010536008033), (3622356130274722018, 162030840677521022192355139208505458492), (17807768575470996347, 264107246314713159406963697924105744409), (5103434150074147746, 331686166459964582006209321975587627262), (5962771466034321974, 300961804728115777587520888809168362574), (2930645694242691907, 127752709774252686733969795258447263979), (16197574560597474644, 245410120683069493317132088266217906749), (12478835478062365617, 103838791113879912161511798836229961653), (5503595333662805357, 92368472243854403026472376408708548349), (18122734335129614364, 288955542597300001147753560885976966029), (12688080215989274550, 85237436689682348751672119832134138932), (4148468277722853958, 297778117327421209654837771300216669574), (8749445804640085302, 79595866493078234154562014325793780126), (12442730869682574563, 196176786402808588883611974143577417817), (6110644747049355904, 26592587989877021920275416199052685135), (5851164380497779369, 158876888501825038083692899057819261957), (9497384378514985275, 15279835675313542048650599472403150097), (10661092311826161857, 250089949043892591422587928179995867509), (10046856000675345423, 231369150063141386398059701278066296663)] } diff --git a/src/api.rs b/src/api.rs index a2a34a2db..dc38498d3 100644 --- a/src/api.rs +++ b/src/api.rs @@ -30,7 +30,7 @@ pub mod downloader; pub mod proto; pub mod remote; pub mod tags; -use crate::api::proto::WaitIdleRequest; +use crate::{api::proto::WaitIdleRequest, provider::events::ProgressError}; pub use crate::{store::util::Tag, util::temp_tag::TempTag}; pub(crate) type ApiClient = irpc::Client; @@ -97,6 +97,8 @@ pub enum ExportBaoError { ExportBaoIo { source: io::Error }, #[snafu(display("encode error: {source}"))] ExportBaoInner { source: bao_tree::io::EncodeError }, + #[snafu(display("client error: {source}"))] + Progress { source: ProgressError }, } impl From for Error { @@ -107,6 +109,7 @@ impl From for Error { ExportBaoError::Request { source, .. } => Self::Io(source.into()), ExportBaoError::ExportBaoIo { source, .. } => Self::Io(source), ExportBaoError::ExportBaoInner { source, .. } => Self::Io(source.into()), + ExportBaoError::Progress { source, .. } => Self::Io(source.into()), } } } @@ -152,6 +155,12 @@ impl From for ExportBaoError { } } +impl From for ExportBaoError { + fn from(value: ProgressError) -> Self { + ProgressSnafu.into_error(value) + } +} + pub type ExportBaoResult = std::result::Result; #[derive(Debug, derive_more::Display, derive_more::From, Serialize, Deserialize)] diff --git a/src/api/blobs.rs b/src/api/blobs.rs index d0b948598..830dd2042 100644 --- a/src/api/blobs.rs +++ b/src/api/blobs.rs @@ -23,14 +23,12 @@ use bao_tree::{ }; use bytes::Bytes; use genawaiter::sync::Gen; -use iroh_io::{AsyncStreamReader, TokioStreamReader}; +use iroh_io::{AsyncStreamReader, AsyncStreamWriter}; use irpc::channel::{mpsc, oneshot}; use n0_future::{future, stream, Stream, StreamExt}; -use quinn::SendStream; use range_collections::{range_set::RangeSetRange, RangeSet2}; use ref_cast::RefCast; use serde::{Deserialize, Serialize}; -use tokio::io::AsyncWriteExt; use tracing::trace; mod reader; pub use reader::BlobReader; @@ -57,7 +55,7 @@ use super::{ }; use crate::{ api::proto::{BatchRequest, ImportByteStreamUpdate}, - provider::StreamContext, + provider::events::ClientResult, store::IROH_BLOCK_SIZE, util::temp_tag::TempTag, BlobFormat, Hash, HashAndFormat, @@ -431,7 +429,7 @@ impl Blobs { } #[cfg_attr(feature = "hide-proto-docs", doc(hidden))] - async fn import_bao_reader( + pub async fn import_bao_reader( &self, hash: Hash, ranges: ChunkRanges, @@ -468,18 +466,6 @@ impl Blobs { Ok(reader?) } - #[cfg_attr(feature = "hide-proto-docs", doc(hidden))] - pub async fn import_bao_quinn( - &self, - hash: Hash, - ranges: ChunkRanges, - stream: &mut iroh::endpoint::RecvStream, - ) -> RequestResult<()> { - let reader = TokioStreamReader::new(stream); - self.import_bao_reader(hash, ranges, reader).await?; - Ok(()) - } - #[cfg_attr(feature = "hide-proto-docs", doc(hidden))] pub async fn import_bao_bytes( &self, @@ -1058,24 +1044,21 @@ impl ExportBaoProgress { Ok(data) } - pub async fn write_quinn(self, target: &mut quinn::SendStream) -> super::ExportBaoResult<()> { + pub async fn write(self, target: &mut W) -> super::ExportBaoResult<()> { let mut rx = self.inner.await?; while let Some(item) = rx.recv().await? { match item { EncodedItem::Size(size) => { - target.write_u64_le(size).await?; + target.write(&size.to_le_bytes()).await?; } EncodedItem::Parent(parent) => { let mut data = vec![0u8; 64]; data[..32].copy_from_slice(parent.pair.0.as_bytes()); data[32..].copy_from_slice(parent.pair.1.as_bytes()); - target.write_all(&data).await.map_err(io::Error::from)?; + target.write(&data).await?; } EncodedItem::Leaf(leaf) => { - target - .write_chunk(leaf.data) - .await - .map_err(io::Error::from)?; + target.write_bytes(leaf.data).await?; } EncodedItem::Done => break, EncodedItem::Error(cause) => return Err(cause.into()), @@ -1085,9 +1068,9 @@ impl ExportBaoProgress { } /// Write quinn variant that also feeds a progress writer. - pub(crate) async fn write_quinn_with_progress( + pub(crate) async fn write_with_progress( self, - writer: &mut SendStream, + writer: &mut W, progress: &mut impl WriteProgress, hash: &Hash, index: u64, @@ -1097,23 +1080,22 @@ impl ExportBaoProgress { match item { EncodedItem::Size(size) => { progress.send_transfer_started(index, hash, size).await; - writer.write_u64_le(size).await?; + writer.write(&size.to_le_bytes()).await?; progress.log_other_write(8); } EncodedItem::Parent(parent) => { let mut data = vec![0u8; 64]; data[..32].copy_from_slice(parent.pair.0.as_bytes()); data[32..].copy_from_slice(parent.pair.1.as_bytes()); - writer.write_all(&data).await.map_err(io::Error::from)?; + writer.write(&data).await?; progress.log_other_write(64); } EncodedItem::Leaf(leaf) => { let len = leaf.data.len(); - writer - .write_chunk(leaf.data) - .await - .map_err(io::Error::from)?; - progress.notify_payload_write(index, leaf.offset, len).await; + writer.write_bytes(leaf.data).await?; + progress + .notify_payload_write(index, leaf.offset, len) + .await?; } EncodedItem::Done => break, EncodedItem::Error(cause) => return Err(cause.into()), @@ -1159,7 +1141,7 @@ impl ExportBaoProgress { pub(crate) trait WriteProgress { /// Notify the progress writer that a payload write has happened. - async fn notify_payload_write(&mut self, index: u64, offset: u64, len: usize); + async fn notify_payload_write(&mut self, index: u64, offset: u64, len: usize) -> ClientResult; /// Log a write of some other data. fn log_other_write(&mut self, len: usize); @@ -1167,17 +1149,3 @@ pub(crate) trait WriteProgress { /// Notify the progress writer that a transfer has started. async fn send_transfer_started(&mut self, index: u64, hash: &Hash, size: u64); } - -impl WriteProgress for StreamContext { - async fn notify_payload_write(&mut self, index: u64, offset: u64, len: usize) { - StreamContext::notify_payload_write(self, index, offset, len); - } - - fn log_other_write(&mut self, len: usize) { - StreamContext::log_other_write(self, len); - } - - async fn send_transfer_started(&mut self, index: u64, hash: &Hash, size: u64) { - StreamContext::send_transfer_started(self, index, hash, size).await - } -} diff --git a/src/api/downloader.rs b/src/api/downloader.rs index a2abbd7ea..8ac188000 100644 --- a/src/api/downloader.rs +++ b/src/api/downloader.rs @@ -3,7 +3,6 @@ use std::{ collections::{HashMap, HashSet}, fmt::Debug, future::{Future, IntoFuture}, - io, sync::Arc, }; @@ -113,7 +112,7 @@ async fn handle_download_impl( SplitStrategy::Split => handle_download_split_impl(store, pool, request, tx).await?, SplitStrategy::None => match request.request { FiniteRequest::Get(get) => { - let sink = IrpcSenderRefSink(tx).with_map_err(io::Error::other); + let sink = IrpcSenderRefSink(tx); execute_get(&pool, Arc::new(get), &request.providers, &store, sink).await?; } FiniteRequest::GetMany(_) => { @@ -143,9 +142,7 @@ async fn handle_download_split_impl( let hash = request.hash; let (tx, rx) = tokio::sync::mpsc::channel::<(usize, DownloadProgessItem)>(16); progress_tx.send(rx).await.ok(); - let sink = TokioMpscSenderSink(tx) - .with_map_err(io::Error::other) - .with_map(move |x| (id, x)); + let sink = TokioMpscSenderSink(tx).with_map(move |x| (id, x)); let res = execute_get(&pool, Arc::new(request), &providers, &store, sink).await; (hash, res) } @@ -375,7 +372,7 @@ async fn split_request<'a>( providers: &Arc, pool: &ConnectionPool, store: &Store, - progress: impl Sink, + progress: impl Sink, ) -> anyhow::Result + Send + 'a>> { Ok(match request { FiniteRequest::Get(req) => { @@ -431,7 +428,7 @@ async fn execute_get( request: Arc, providers: &Arc, store: &Store, - mut progress: impl Sink, + mut progress: impl Sink, ) -> anyhow::Result<()> { let remote = store.remote(); let mut providers = providers.find_providers(request.content()); @@ -566,9 +563,7 @@ mod tests { .download(request, Shuffled::new(vec![node1_id, node2_id])) .stream() .await?; - while let Some(item) = progress.next().await { - println!("Got item: {item:?}"); - } + while progress.next().await.is_some() {} assert_eq!(store3.get_bytes(tt1.hash).await?.deref(), b"hello world"); assert_eq!(store3.get_bytes(tt2.hash).await?.deref(), b"hello world 2"); Ok(()) @@ -611,9 +606,7 @@ mod tests { )) .stream() .await?; - while let Some(item) = progress.next().await { - println!("Got item: {item:?}"); - } + while progress.next().await.is_some() {} } if false { let conn = r3.endpoint().connect(node1_addr, crate::ALPN).await?; @@ -675,9 +668,7 @@ mod tests { )) .stream() .await?; - while let Some(item) = progress.next().await { - println!("Got item: {item:?}"); - } + while progress.next().await.is_some() {} Ok(()) } } diff --git a/src/api/remote.rs b/src/api/remote.rs index 623200900..8877c6297 100644 --- a/src/api/remote.rs +++ b/src/api/remote.rs @@ -8,16 +8,21 @@ use n0_future::{io, Stream, StreamExt}; use n0_snafu::SpanTrace; use nested_enum_utils::common_fields; use ref_cast::RefCast; -use snafu::{Backtrace, IntoError, Snafu}; +use snafu::{Backtrace, IntoError, ResultExt, Snafu}; use super::blobs::{Bitfield, ExportBaoOptions}; use crate::{ api::{blobs::WriteProgress, ApiClient}, - get::{fsm::DecodeError, BadRequestSnafu, GetError, GetResult, LocalFailureSnafu, Stats}, + get::{ + fsm::DecodeError, + get_error::{BadRequestSnafu, LocalFailureSnafu}, + GetError, GetResult, IrohStreamWriter, Stats, + }, protocol::{ GetManyRequest, ObserveItem, ObserveRequest, PushRequest, Request, RequestType, MAX_MESSAGE_SIZE, }, + provider::events::{ClientResult, ProgressError}, util::sink::{Sink, TokioMpscSenderSink}, }; @@ -94,8 +99,7 @@ impl GetProgress { pub async fn complete(self) -> GetResult { just_result(self.stream()).await.unwrap_or_else(|| { - Err(LocalFailureSnafu - .into_error(anyhow::anyhow!("stream closed without result").into())) + Err(LocalFailureSnafu.into_error(anyhow::anyhow!("stream closed without result"))) }) } } @@ -478,9 +482,7 @@ impl Remote { let content = content.into(); let (tx, rx) = tokio::sync::mpsc::channel(64); let tx2 = tx.clone(); - let sink = TokioMpscSenderSink(tx) - .with_map(GetProgressItem::Progress) - .with_map_err(io::Error::other); + let sink = TokioMpscSenderSink(tx).with_map(GetProgressItem::Progress); let this = self.clone(); let fut = async move { let res = this.fetch_sink(conn, content, sink).await.into(); @@ -503,13 +505,13 @@ impl Remote { &self, mut conn: impl GetConnection, content: impl Into, - progress: impl Sink, + progress: impl Sink, ) -> GetResult { let content = content.into(); let local = self .local(content) .await - .map_err(|e| LocalFailureSnafu.into_error(e.into()))?; + .map_err(|e: anyhow::Error| LocalFailureSnafu.into_error(e))?; if local.is_complete() { return Ok(Default::default()); } @@ -517,7 +519,7 @@ impl Remote { let conn = conn .connection() .await - .map_err(|e| LocalFailureSnafu.into_error(e.into()))?; + .map_err(|e| LocalFailureSnafu.into_error(e))?; let stats = self.execute_get_sink(&conn, request, progress).await?; Ok(stats) } @@ -556,9 +558,7 @@ impl Remote { pub fn execute_push(&self, conn: Connection, request: PushRequest) -> PushProgress { let (tx, rx) = tokio::sync::mpsc::channel(64); let tx2 = tx.clone(); - let sink = TokioMpscSenderSink(tx) - .with_map(PushProgressItem::Progress) - .with_map_err(io::Error::other); + let sink = TokioMpscSenderSink(tx).with_map(PushProgressItem::Progress); let this = self.clone(); let fut = async move { let res = this.execute_push_sink(conn, request, sink).await.into(); @@ -577,7 +577,7 @@ impl Remote { &self, conn: Connection, request: PushRequest, - progress: impl Sink, + progress: impl Sink, ) -> anyhow::Result { let hash = request.hash; debug!(%hash, "pushing"); @@ -593,15 +593,16 @@ impl Remote { let mut request_ranges = request.ranges.iter_infinite(); let root = request.hash; let root_ranges = request_ranges.next().expect("infinite iterator"); + let mut send = IrohStreamWriter(send); if !root_ranges.is_empty() { self.store() .export_bao(root, root_ranges.clone()) - .write_quinn_with_progress(&mut send, &mut context, &root, 0) + .write_with_progress(&mut send, &mut context, &root, 0) .await?; } if request.ranges.is_blob() { // we are done - send.finish()?; + send.0.finish()?; return Ok(Default::default()); } let hash_seq = self.store().get_bytes(root).await?; @@ -612,16 +613,11 @@ impl Remote { if !child_ranges.is_empty() { self.store() .export_bao(child_hash, child_ranges.clone()) - .write_quinn_with_progress( - &mut send, - &mut context, - &child_hash, - (child + 1) as u64, - ) + .write_with_progress(&mut send, &mut context, &child_hash, (child + 1) as u64) .await?; } } - send.finish()?; + send.0.finish()?; Ok(Default::default()) } @@ -632,9 +628,7 @@ impl Remote { pub fn execute_get_with_opts(&self, conn: Connection, request: GetRequest) -> GetProgress { let (tx, rx) = tokio::sync::mpsc::channel(64); let tx2 = tx.clone(); - let sink = TokioMpscSenderSink(tx) - .with_map(GetProgressItem::Progress) - .with_map_err(io::Error::other); + let sink = TokioMpscSenderSink(tx).with_map(GetProgressItem::Progress); let this = self.clone(); let fut = async move { let res = this.execute_get_sink(&conn, request, sink).await.into(); @@ -658,7 +652,7 @@ impl Remote { &self, conn: &Connection, request: GetRequest, - mut progress: impl Sink, + mut progress: impl Sink, ) -> GetResult { let store = self.store(); let root = request.hash; @@ -690,7 +684,7 @@ impl Remote { .await .map_err(|e| LocalFailureSnafu.into_error(e.into()))?, ) - .map_err(|source| BadRequestSnafu.into_error(source.into()))?; + .context(BadRequestSnafu)?; // let mut hash_seq = LazyHashSeq::new(store.blobs().clone(), root); loop { let at_start_child = match next_child { @@ -721,9 +715,7 @@ impl Remote { pub fn execute_get_many(&self, conn: Connection, request: GetManyRequest) -> GetProgress { let (tx, rx) = tokio::sync::mpsc::channel(64); let tx2 = tx.clone(); - let sink = TokioMpscSenderSink(tx) - .with_map(GetProgressItem::Progress) - .with_map_err(io::Error::other); + let sink = TokioMpscSenderSink(tx).with_map(GetProgressItem::Progress); let this = self.clone(); let fut = async move { let res = this.execute_get_many_sink(conn, request, sink).await.into(); @@ -747,7 +739,7 @@ impl Remote { &self, conn: Connection, request: GetManyRequest, - mut progress: impl Sink, + mut progress: impl Sink, ) -> GetResult { let store = self.store(); let hash_seq = request.hashes.iter().copied().collect::(); @@ -762,7 +754,6 @@ impl Remote { Err(at_closing) => break at_closing, }; let offset = at_start_child.offset(); - println!("offset {offset}"); let Some(hash) = hash_seq.get(offset as usize) else { break at_start_child.finish(); }; @@ -884,7 +875,7 @@ async fn get_blob_ranges_impl( header: AtBlobHeader, hash: Hash, store: &Store, - mut progress: impl Sink, + mut progress: impl Sink, ) -> GetResult { let (mut content, size) = header.next().await?; let Some(size) = NonZeroU64::new(size) else { @@ -922,8 +913,7 @@ async fn get_blob_ranges_impl( }; let complete = async move { handle.rx.await.map_err(|e| { - LocalFailureSnafu - .into_error(anyhow::anyhow!("error reading from import stream: {e}").into()) + LocalFailureSnafu.into_error(anyhow::anyhow!("error reading from import stream: {e}")) }) }; let (_, end) = tokio::try_join!(complete, write)?; @@ -1048,11 +1038,20 @@ struct StreamContext { impl WriteProgress for StreamContext where - S: Sink, + S: Sink, { - async fn notify_payload_write(&mut self, _index: u64, _offset: u64, len: usize) { + async fn notify_payload_write( + &mut self, + _index: u64, + _offset: u64, + len: usize, + ) -> ClientResult { self.payload_bytes_sent += len as u64; - self.sender.send(self.payload_bytes_sent).await.ok(); + self.sender + .send(self.payload_bytes_sent) + .await + .map_err(|e| ProgressError::Internal { source: e.into() })?; + Ok(()) } fn log_other_write(&mut self, _len: usize) {} diff --git a/src/get.rs b/src/get.rs index 049ef4855..3accc55f4 100644 --- a/src/get.rs +++ b/src/get.rs @@ -17,30 +17,73 @@ //! //! [iroh]: https://docs.rs/iroh use std::{ - error::Error, fmt::{self, Debug}, + io, time::{Duration, Instant}, }; use anyhow::Result; use bao_tree::{io::fsm::BaoContentItem, ChunkNum}; use fsm::RequestCounters; -use iroh::endpoint::{self, RecvStream, SendStream}; -use iroh_io::TokioStreamReader; +use iroh_io::{AsyncStreamReader, AsyncStreamWriter}; use n0_snafu::SpanTrace; use nested_enum_utils::common_fields; +use quinn::ReadExactError; use serde::{Deserialize, Serialize}; use snafu::{Backtrace, IntoError, ResultExt, Snafu}; +use tokio::io::AsyncWriteExt; use tracing::{debug, error}; use crate::{protocol::ChunkRangesSeq, store::IROH_BLOCK_SIZE, Hash}; mod error; pub mod request; -pub(crate) use error::{BadRequestSnafu, LocalFailureSnafu}; +pub(crate) use error::get_error; pub use error::{GetError, GetResult}; -type WrappedRecvStream = TokioStreamReader; +pub struct IrohStreamWriter(pub iroh::endpoint::SendStream); + +impl AsyncStreamWriter for IrohStreamWriter { + async fn write(&mut self, data: &[u8]) -> io::Result<()> { + Ok(self.0.write_all(data).await?) + } + + async fn write_bytes(&mut self, data: bytes::Bytes) -> io::Result<()> { + Ok(self.0.write_chunk(data).await?) + } + + async fn sync(&mut self) -> io::Result<()> { + self.0.flush().await + } +} + +pub struct IrohStreamReader(pub iroh::endpoint::RecvStream); + +impl AsyncStreamReader for IrohStreamReader { + async fn read(&mut self) -> io::Result<[u8; N]> { + let mut buf = [0u8; N]; + match self.0.read_exact(&mut buf).await { + Ok(()) => Ok(buf), + Err(ReadExactError::ReadError(e)) => Err(e.into()), + Err(ReadExactError::FinishedEarly(_)) => Err(io::ErrorKind::UnexpectedEof.into()), + } + } + + async fn read_bytes(&mut self, len: usize) -> io::Result { + let mut buf = vec![0u8; len]; + match self.0.read_exact(&mut buf).await { + Ok(()) => Ok(buf.into()), + Err(ReadExactError::ReadError(e)) => Err(e.into()), + Err(ReadExactError::FinishedEarly(n)) => { + buf.truncate(n); + Ok(buf.into()) + } + } + } +} + +type DefaultReader = IrohStreamReader; +type DefaultWriter = IrohStreamWriter; /// Stats about the transfer. #[derive( @@ -96,11 +139,11 @@ pub mod fsm { }; use derive_more::From; use iroh::endpoint::Connection; - use iroh_io::{AsyncSliceWriter, AsyncStreamReader, TokioStreamReader}; + use iroh_io::{AsyncSliceWriter, AsyncStreamReader, AsyncStreamWriter}; use super::*; use crate::{ - get::error::BadRequestSnafu, + get::get_error::BadRequestSnafu, protocol::{ GetManyRequest, GetRequest, NonEmptyRequestRangeSpecIter, Request, MAX_MESSAGE_SIZE, }, @@ -130,16 +173,22 @@ pub mod fsm { counters: RequestCounters, ) -> std::result::Result, GetError> { let start = Instant::now(); - let (mut writer, reader) = connection.open_bi().await?; + let (writer, reader) = connection + .open_bi() + .await + .map_err(|e| OpenSnafu.into_error(e.into()))?; + let reader = IrohStreamReader(reader); + let mut writer = IrohStreamWriter(writer); let request = Request::GetMany(request); let request_bytes = postcard::to_stdvec(&request) .map_err(|source| BadRequestSnafu.into_error(source.into()))?; - writer.write_all(&request_bytes).await?; - writer.finish()?; + writer + .write_bytes(request_bytes.into()) + .await + .context(connected_next_error::WriteSnafu)?; let Request::GetMany(request) = request else { unreachable!(); }; - let reader = TokioStreamReader::new(reader); let mut ranges_iter = RangesIter::new(request.ranges.clone()); let first_item = ranges_iter.next(); let misc = Box::new(Misc { @@ -214,10 +263,15 @@ pub mod fsm { } /// Initiate a new bidi stream to use for the get response - pub async fn next(self) -> Result { + pub async fn next(self) -> Result { let start = Instant::now(); - let (writer, reader) = self.connection.open_bi().await?; - let reader = TokioStreamReader::new(reader); + let (writer, reader) = self + .connection + .open_bi() + .await + .map_err(|e| OpenSnafu.into_error(e.into()))?; + let reader = IrohStreamReader(reader); + let writer = IrohStreamWriter(writer); Ok(AtConnected { start, reader, @@ -228,25 +282,41 @@ pub mod fsm { } } + /// Error that you can get from [`AtConnected::next`] + #[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: SpanTrace, + })] + #[allow(missing_docs)] + #[derive(Debug, Snafu)] + #[non_exhaustive] + pub enum InitialNextError { + Open { source: io::Error }, + } + /// State of the get response machine after the handshake has been sent #[derive(Debug)] - pub struct AtConnected { + pub struct AtConnected< + R: AsyncStreamReader = DefaultReader, + W: AsyncStreamWriter = DefaultWriter, + > { start: Instant, - reader: WrappedRecvStream, - writer: SendStream, + reader: R, + writer: W, request: GetRequest, counters: RequestCounters, } /// Possible next states after the handshake has been sent #[derive(Debug, From)] - pub enum ConnectedNext { + pub enum ConnectedNext { /// First response is either a collection or a single blob - StartRoot(AtStartRoot), + StartRoot(AtStartRoot), /// First response is a child - StartChild(AtStartChild), + StartChild(AtStartChild), /// Request is empty - Closing(AtClosing), + Closing(AtClosing), } /// Error that you can get from [`AtConnected::next`] @@ -257,6 +327,7 @@ pub mod fsm { })] #[allow(missing_docs)] #[derive(Debug, Snafu)] + #[snafu(module)] #[non_exhaustive] pub enum ConnectedNextError { /// Error when serializing the request @@ -267,23 +338,33 @@ pub mod fsm { RequestTooBig {}, /// Error when writing the request to the [`SendStream`]. #[snafu(display("write: {source}"))] - Write { source: quinn::WriteError }, - /// Quic connection is closed. - #[snafu(display("closed"))] - Closed { source: quinn::ClosedStream }, - /// A generic io error - #[snafu(transparent)] - Io { source: io::Error }, + Write { source: io::Error }, } - impl AtConnected { + impl AtConnected { + pub fn new( + start: Instant, + reader: R, + writer: W, + request: GetRequest, + counters: RequestCounters, + ) -> Self { + Self { + start, + reader, + writer, + request, + counters, + } + } + /// Send the request and move to the next state /// /// The next state will be either `StartRoot` or `StartChild` depending on whether /// the request requests part of the collection or not. /// /// If the request is empty, this can also move directly to `Finished`. - pub async fn next(self) -> Result { + pub async fn next(self) -> Result, ConnectedNextError> { let Self { start, reader, @@ -295,23 +376,32 @@ pub mod fsm { counters.other_bytes_written += { debug!("sending request"); let wrapped = Request::Get(request); - let request_bytes = postcard::to_stdvec(&wrapped).context(PostcardSerSnafu)?; + let request_bytes = postcard::to_stdvec(&wrapped) + .context(connected_next_error::PostcardSerSnafu)?; let Request::Get(x) = wrapped else { unreachable!(); }; request = x; if request_bytes.len() > MAX_MESSAGE_SIZE { - return Err(RequestTooBigSnafu.build()); + return Err(connected_next_error::RequestTooBigSnafu.build()); } // write the request itself - writer.write_all(&request_bytes).await.context(WriteSnafu)?; - request_bytes.len() as u64 + let len = request_bytes.len() as u64; + writer + .write_bytes(request_bytes.into()) + .await + .context(connected_next_error::WriteSnafu)?; + writer + .sync() + .await + .context(connected_next_error::WriteSnafu)?; + len }; // 2. Finish writing before expecting a response - writer.finish().context(ClosedSnafu)?; + drop(writer); let hash = request.hash; let ranges_iter = RangesIter::new(request.ranges); @@ -348,23 +438,23 @@ pub mod fsm { /// State of the get response when we start reading a collection #[derive(Debug)] - pub struct AtStartRoot { + pub struct AtStartRoot { ranges: ChunkRanges, - reader: TokioStreamReader, + reader: R, misc: Box, hash: Hash, } /// State of the get response when we start reading a child #[derive(Debug)] - pub struct AtStartChild { + pub struct AtStartChild { ranges: ChunkRanges, - reader: TokioStreamReader, + reader: R, misc: Box, offset: u64, } - impl AtStartChild { + impl AtStartChild { /// The offset of the child we are currently reading /// /// This must be used to determine the hash needed to call next. @@ -382,7 +472,7 @@ pub mod fsm { /// Go into the next state, reading the header /// /// This requires passing in the hash of the child for validation - pub fn next(self, hash: Hash) -> AtBlobHeader { + pub fn next(self, hash: Hash) -> AtBlobHeader { AtBlobHeader { reader: self.reader, ranges: self.ranges, @@ -396,12 +486,12 @@ pub mod fsm { /// This is used if you know that there are no more children from having /// read the collection, or when you want to stop reading the response /// early. - pub fn finish(self) -> AtClosing { + pub fn finish(self) -> AtClosing { AtClosing::new(self.misc, self.reader, false) } } - impl AtStartRoot { + impl AtStartRoot { /// The ranges we have requested for the child pub fn ranges(&self) -> &ChunkRanges { &self.ranges @@ -415,7 +505,7 @@ pub mod fsm { /// Go into the next state, reading the header /// /// For the collection we already know the hash, since it was part of the request - pub fn next(self) -> AtBlobHeader { + pub fn next(self) -> AtBlobHeader { AtBlobHeader { reader: self.reader, ranges: self.ranges, @@ -425,16 +515,16 @@ pub mod fsm { } /// Finish the get response without reading further - pub fn finish(self) -> AtClosing { + pub fn finish(self) -> AtClosing { AtClosing::new(self.misc, self.reader, false) } } /// State before reading a size header #[derive(Debug)] - pub struct AtBlobHeader { + pub struct AtBlobHeader { ranges: ChunkRanges, - reader: TokioStreamReader, + reader: R, misc: Box, hash: Hash, } @@ -447,18 +537,16 @@ pub mod fsm { })] #[non_exhaustive] #[derive(Debug, Snafu)] + #[snafu(module)] pub enum AtBlobHeaderNextError { /// Eof when reading the size header /// /// This indicates that the provider does not have the requested data. #[snafu(display("not found"))] NotFound {}, - /// Quinn read error when reading the size header - #[snafu(display("read: {source}"))] - EndpointRead { source: endpoint::ReadError }, /// Generic io error #[snafu(display("io: {source}"))] - Io { source: io::Error }, + Read { source: io::Error }, } impl From for io::Error { @@ -467,25 +555,19 @@ pub mod fsm { AtBlobHeaderNextError::NotFound { .. } => { io::Error::new(io::ErrorKind::UnexpectedEof, cause) } - AtBlobHeaderNextError::EndpointRead { source, .. } => source.into(), - AtBlobHeaderNextError::Io { source, .. } => source, + AtBlobHeaderNextError::Read { source, .. } => source, } } } - impl AtBlobHeader { + impl AtBlobHeader { /// Read the size header, returning it and going into the `Content` state. - pub async fn next(mut self) -> Result<(AtBlobContent, u64), AtBlobHeaderNextError> { + pub async fn next(mut self) -> Result<(AtBlobContent, u64), AtBlobHeaderNextError> { let size = self.reader.read::<8>().await.map_err(|cause| { if cause.kind() == io::ErrorKind::UnexpectedEof { - NotFoundSnafu.build() - } else if let Some(e) = cause - .get_ref() - .and_then(|x| x.downcast_ref::()) - { - EndpointReadSnafu.into_error(e.clone()) + at_blob_header_next_error::NotFoundSnafu.build() } else { - IoSnafu.into_error(cause) + at_blob_header_next_error::ReadSnafu.into_error(cause) } })?; self.misc.other_bytes_read += 8; @@ -506,7 +588,7 @@ pub mod fsm { } /// Drain the response and throw away the result - pub async fn drain(self) -> result::Result { + pub async fn drain(self) -> result::Result, DecodeError> { let (content, _size) = self.next().await?; content.drain().await } @@ -517,7 +599,7 @@ pub mod fsm { /// concatenate the ranges that were requested. pub async fn concatenate_into_vec( self, - ) -> result::Result<(AtEndBlob, Vec), DecodeError> { + ) -> result::Result<(AtEndBlob, Vec), DecodeError> { let (content, _size) = self.next().await?; content.concatenate_into_vec().await } @@ -526,7 +608,7 @@ pub mod fsm { pub async fn write_all( self, data: D, - ) -> result::Result { + ) -> result::Result, DecodeError> { let (content, _size) = self.next().await?; let res = content.write_all(data).await?; Ok(res) @@ -540,7 +622,7 @@ pub mod fsm { self, outboard: Option, data: D, - ) -> result::Result + ) -> result::Result, DecodeError> where D: AsyncSliceWriter, O: OutboardMut, @@ -568,8 +650,8 @@ pub mod fsm { /// State while we are reading content #[derive(Debug)] - pub struct AtBlobContent { - stream: ResponseDecoder, + pub struct AtBlobContent { + stream: ResponseDecoder, misc: Box, } @@ -603,6 +685,7 @@ pub mod fsm { })] #[non_exhaustive] #[derive(Debug, Snafu)] + #[snafu(module)] pub enum DecodeError { /// A chunk was not found or invalid, so the provider stopped sending data #[snafu(display("not found"))] @@ -621,24 +704,25 @@ pub mod fsm { LeafHashMismatch { num: ChunkNum }, /// Error when reading from the stream #[snafu(display("read: {source}"))] - Read { source: endpoint::ReadError }, + Read { source: io::Error }, /// A generic io error #[snafu(display("io: {source}"))] - DecodeIo { source: io::Error }, + Write { source: io::Error }, } impl DecodeError { pub(crate) fn leaf_hash_mismatch(num: ChunkNum) -> Self { - LeafHashMismatchSnafu { num }.build() + decode_error::LeafHashMismatchSnafu { num }.build() } } impl From for DecodeError { fn from(cause: AtBlobHeaderNextError) -> Self { match cause { - AtBlobHeaderNextError::NotFound { .. } => ChunkNotFoundSnafu.build(), - AtBlobHeaderNextError::EndpointRead { source, .. } => ReadSnafu.into_error(source), - AtBlobHeaderNextError::Io { source, .. } => DecodeIoSnafu.into_error(source), + AtBlobHeaderNextError::NotFound { .. } => decode_error::ChunkNotFoundSnafu.build(), + AtBlobHeaderNextError::Read { source, .. } => { + decode_error::ReadSnafu.into_error(source) + } } } } @@ -652,59 +736,50 @@ pub mod fsm { DecodeError::LeafNotFound { .. } => { io::Error::new(io::ErrorKind::UnexpectedEof, cause) } - DecodeError::Read { source, .. } => source.into(), - DecodeError::DecodeIo { source, .. } => source, + DecodeError::Read { source, .. } => source, + DecodeError::Write { source, .. } => source, _ => io::Error::other(cause), } } } - impl From for DecodeError { - fn from(value: io::Error) -> Self { - DecodeIoSnafu.into_error(value) - } - } - impl From for DecodeError { fn from(value: bao_tree::io::DecodeError) -> Self { match value { - bao_tree::io::DecodeError::ParentNotFound(x) => { - ParentNotFoundSnafu { node: x }.build() + bao_tree::io::DecodeError::ParentNotFound(node) => { + decode_error::ParentNotFoundSnafu { node }.build() + } + bao_tree::io::DecodeError::LeafNotFound(num) => { + decode_error::LeafNotFoundSnafu { num }.build() } - bao_tree::io::DecodeError::LeafNotFound(x) => LeafNotFoundSnafu { num: x }.build(), bao_tree::io::DecodeError::ParentHashMismatch(node) => { - ParentHashMismatchSnafu { node }.build() + decode_error::ParentHashMismatchSnafu { node }.build() } - bao_tree::io::DecodeError::LeafHashMismatch(chunk) => { - LeafHashMismatchSnafu { num: chunk }.build() - } - bao_tree::io::DecodeError::Io(cause) => { - if let Some(inner) = cause.get_ref() { - if let Some(e) = inner.downcast_ref::() { - ReadSnafu.into_error(e.clone()) - } else { - DecodeIoSnafu.into_error(cause) - } - } else { - DecodeIoSnafu.into_error(cause) - } + bao_tree::io::DecodeError::LeafHashMismatch(num) => { + decode_error::LeafHashMismatchSnafu { num }.build() } + bao_tree::io::DecodeError::Io(cause) => decode_error::ReadSnafu.into_error(cause), } } } /// The next state after reading a content item #[derive(Debug, From)] - pub enum BlobContentNext { + pub enum BlobContentNext { /// We expect more content - More((AtBlobContent, result::Result)), + More( + ( + AtBlobContent, + result::Result, + ), + ), /// We are done with this blob - Done(AtEndBlob), + Done(AtEndBlob), } - impl AtBlobContent { + impl AtBlobContent { /// Read the next item, either content, an error, or the end of the blob - pub async fn next(self) -> BlobContentNext { + pub async fn next(self) -> BlobContentNext { match self.stream.next().await { ResponseDecoderNext::More((stream, res)) => { let mut next = Self { stream, ..self }; @@ -751,7 +826,7 @@ pub mod fsm { } /// Drain the response and throw away the result - pub async fn drain(self) -> result::Result { + pub async fn drain(self) -> result::Result, DecodeError> { let mut content = self; loop { match content.next().await { @@ -769,7 +844,7 @@ pub mod fsm { /// Concatenate the entire response into a vec pub async fn concatenate_into_vec( self, - ) -> result::Result<(AtEndBlob, Vec), DecodeError> { + ) -> result::Result<(AtEndBlob, Vec), DecodeError> { let mut res = Vec::with_capacity(1024); let mut curr = self; let done = loop { @@ -797,7 +872,7 @@ pub mod fsm { self, mut outboard: Option, mut data: D, - ) -> result::Result + ) -> result::Result, DecodeError> where D: AsyncSliceWriter, O: OutboardMut, @@ -810,11 +885,16 @@ pub mod fsm { match item? { BaoContentItem::Parent(parent) => { if let Some(outboard) = outboard.as_mut() { - outboard.save(parent.node, &parent.pair).await?; + outboard + .save(parent.node, &parent.pair) + .await + .map_err(|e| decode_error::WriteSnafu.into_error(e))?; } } BaoContentItem::Leaf(leaf) => { - data.write_bytes_at(leaf.offset, leaf.data).await?; + data.write_bytes_at(leaf.offset, leaf.data) + .await + .map_err(|e| decode_error::WriteSnafu.into_error(e))?; } } } @@ -826,7 +906,7 @@ pub mod fsm { } /// Write the entire blob to a slice writer. - pub async fn write_all(self, mut data: D) -> result::Result + pub async fn write_all(self, mut data: D) -> result::Result, DecodeError> where D: AsyncSliceWriter, { @@ -838,7 +918,9 @@ pub mod fsm { match item? { BaoContentItem::Parent(_) => {} BaoContentItem::Leaf(leaf) => { - data.write_bytes_at(leaf.offset, leaf.data).await?; + data.write_bytes_at(leaf.offset, leaf.data) + .await + .map_err(|e| decode_error::WriteSnafu.into_error(e))?; } } } @@ -850,30 +932,30 @@ pub mod fsm { } /// Immediately finish the get response without reading further - pub fn finish(self) -> AtClosing { + pub fn finish(self) -> AtClosing { AtClosing::new(self.misc, self.stream.finish(), false) } } /// State after we have read all the content for a blob #[derive(Debug)] - pub struct AtEndBlob { - stream: WrappedRecvStream, + pub struct AtEndBlob { + stream: R, misc: Box, } /// The next state after the end of a blob #[derive(Debug, From)] - pub enum EndBlobNext { + pub enum EndBlobNext { /// Response is expected to have more children - MoreChildren(AtStartChild), + MoreChildren(AtStartChild), /// No more children expected - Closing(AtClosing), + Closing(AtClosing), } - impl AtEndBlob { + impl AtEndBlob { /// Read the next child, or finish - pub fn next(mut self) -> EndBlobNext { + pub fn next(mut self) -> EndBlobNext { if let Some((offset, ranges)) = self.misc.ranges_iter.next() { AtStartChild { reader: self.stream, @@ -890,14 +972,14 @@ pub mod fsm { /// State when finishing the get response #[derive(Debug)] - pub struct AtClosing { + pub struct AtClosing { misc: Box, - reader: WrappedRecvStream, + reader: R, check_extra_data: bool, } - impl AtClosing { - fn new(misc: Box, reader: WrappedRecvStream, check_extra_data: bool) -> Self { + impl AtClosing { + fn new(misc: Box, reader: R, check_extra_data: bool) -> Self { Self { misc, reader, @@ -906,17 +988,14 @@ pub mod fsm { } /// Finish the get response, returning statistics - pub async fn next(self) -> result::Result { + pub async fn next(self) -> result::Result { // Shut down the stream - let reader = self.reader; - let mut reader = reader.into_inner(); + let mut reader = self.reader; if self.check_extra_data { - if let Some(chunk) = reader.read_chunk(8, false).await? { - reader.stop(0u8.into()).ok(); - error!("Received unexpected data from the provider: {chunk:?}"); + let rest = reader.read_bytes(1).await?; + if !rest.is_empty() { + error!("Unexpected extra data at the end of the stream"); } - } else { - reader.stop(0u8.into()).ok(); } Ok(Stats { counters: self.misc.counters, @@ -925,6 +1004,21 @@ pub mod fsm { } } + /// Error that you can get from [`AtBlobHeader::next`] + #[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: SpanTrace, + })] + #[non_exhaustive] + #[derive(Debug, Snafu)] + #[snafu(module)] + pub enum AtClosingNextError { + /// Generic io error + #[snafu(transparent)] + Read { source: io::Error }, + } + #[derive(Debug, Serialize, Deserialize, Default, Clone, Copy, PartialEq, Eq)] pub struct RequestCounters { /// payload bytes written @@ -950,71 +1044,3 @@ pub mod fsm { ranges_iter: RangesIter, } } - -/// Error when processing a response -#[common_fields({ - backtrace: Option, - #[snafu(implicit)] - span_trace: SpanTrace, -})] -#[allow(missing_docs)] -#[non_exhaustive] -#[derive(Debug, Snafu)] -pub enum GetResponseError { - /// Error when opening a stream - #[snafu(display("connection: {source}"))] - Connection { source: endpoint::ConnectionError }, - /// Error when writing the handshake or request to the stream - #[snafu(display("write: {source}"))] - Write { source: endpoint::WriteError }, - /// Error when reading from the stream - #[snafu(display("read: {source}"))] - Read { source: endpoint::ReadError }, - /// Error when decoding, e.g. hash mismatch - #[snafu(display("decode: {source}"))] - Decode { source: bao_tree::io::DecodeError }, - /// A generic error - #[snafu(display("generic: {source}"))] - Generic { source: anyhow::Error }, -} - -impl From for GetResponseError { - fn from(cause: postcard::Error) -> Self { - GenericSnafu.into_error(cause.into()) - } -} - -impl From for GetResponseError { - fn from(cause: bao_tree::io::DecodeError) -> Self { - match cause { - bao_tree::io::DecodeError::Io(cause) => { - // try to downcast to specific quinn errors - if let Some(source) = cause.source() { - if let Some(error) = source.downcast_ref::() { - return ConnectionSnafu.into_error(error.clone()); - } - if let Some(error) = source.downcast_ref::() { - return ReadSnafu.into_error(error.clone()); - } - if let Some(error) = source.downcast_ref::() { - return WriteSnafu.into_error(error.clone()); - } - } - GenericSnafu.into_error(cause.into()) - } - _ => DecodeSnafu.into_error(cause), - } - } -} - -impl From for GetResponseError { - fn from(cause: anyhow::Error) -> Self { - GenericSnafu.into_error(cause) - } -} - -impl From for std::io::Error { - fn from(cause: GetResponseError) -> Self { - Self::other(cause) - } -} diff --git a/src/get/error.rs b/src/get/error.rs index 1c3ea9465..5cc44e35b 100644 --- a/src/get/error.rs +++ b/src/get/error.rs @@ -1,102 +1,15 @@ //! Error returned from get operations use std::io; -use iroh::endpoint::{self, ClosedStream}; +use iroh::endpoint::{ConnectionError, ReadError, VarInt, WriteError}; use n0_snafu::SpanTrace; use nested_enum_utils::common_fields; -use quinn::{ConnectionError, ReadError, WriteError}; -use snafu::{Backtrace, IntoError, Snafu}; +use snafu::{Backtrace, Snafu}; -use crate::{ - api::ExportBaoError, - get::fsm::{AtBlobHeaderNextError, ConnectedNextError, DecodeError}, +use crate::get::fsm::{ + AtBlobHeaderNextError, AtClosingNextError, ConnectedNextError, DecodeError, InitialNextError, }; -#[derive(Debug, Snafu)] -pub enum NotFoundCases { - #[snafu(transparent)] - AtBlobHeaderNext { source: AtBlobHeaderNextError }, - #[snafu(transparent)] - Decode { source: DecodeError }, -} - -#[derive(Debug, Snafu)] -pub enum NoncompliantNodeCases { - #[snafu(transparent)] - Connection { source: ConnectionError }, - #[snafu(transparent)] - Decode { source: DecodeError }, -} - -#[derive(Debug, Snafu)] -pub enum RemoteResetCases { - #[snafu(transparent)] - Read { source: ReadError }, - #[snafu(transparent)] - Write { source: WriteError }, - #[snafu(transparent)] - Connection { source: ConnectionError }, -} - -#[derive(Debug, Snafu)] -pub enum BadRequestCases { - #[snafu(transparent)] - Anyhow { source: anyhow::Error }, - #[snafu(transparent)] - Postcard { source: postcard::Error }, - #[snafu(transparent)] - ConnectedNext { source: ConnectedNextError }, -} - -#[derive(Debug, Snafu)] -pub enum LocalFailureCases { - #[snafu(transparent)] - Io { - source: io::Error, - }, - #[snafu(transparent)] - Anyhow { - source: anyhow::Error, - }, - #[snafu(transparent)] - IrpcSend { - source: irpc::channel::SendError, - }, - #[snafu(transparent)] - Irpc { - source: irpc::Error, - }, - #[snafu(transparent)] - ExportBao { - source: ExportBaoError, - }, - TokioSend {}, -} - -impl From> for LocalFailureCases { - fn from(_: tokio::sync::mpsc::error::SendError) -> Self { - LocalFailureCases::TokioSend {} - } -} - -#[derive(Debug, Snafu)] -pub enum IoCases { - #[snafu(transparent)] - Io { source: io::Error }, - #[snafu(transparent)] - ConnectionError { source: endpoint::ConnectionError }, - #[snafu(transparent)] - ReadError { source: endpoint::ReadError }, - #[snafu(transparent)] - WriteError { source: endpoint::WriteError }, - #[snafu(transparent)] - ClosedStream { source: endpoint::ClosedStream }, - #[snafu(transparent)] - ConnectedNextError { source: ConnectedNextError }, - #[snafu(transparent)] - AtBlobHeaderNextError { source: AtBlobHeaderNextError }, -} - /// Failures for a get operation #[common_fields({ backtrace: Option, @@ -105,210 +18,112 @@ pub enum IoCases { })] #[derive(Debug, Snafu)] #[snafu(visibility(pub(crate)))] +#[snafu(module)] pub enum GetError { - /// Hash not found, or a requested chunk for the hash not found. - #[snafu(display("Data for hash not found"))] - NotFound { - #[snafu(source(from(NotFoundCases, Box::new)))] - source: Box, + #[snafu(transparent)] + InitialNext { + source: InitialNextError, }, - /// Remote has reset the connection. - #[snafu(display("Remote has reset the connection"))] - RemoteReset { - #[snafu(source(from(RemoteResetCases, Box::new)))] - source: Box, + #[snafu(transparent)] + ConnectedNext { + source: ConnectedNextError, }, - /// Remote behaved in a non-compliant way. - #[snafu(display("Remote behaved in a non-compliant way"))] - NoncompliantNode { - #[snafu(source(from(NoncompliantNodeCases, Box::new)))] - source: Box, + #[snafu(transparent)] + AtBlobHeaderNext { + source: AtBlobHeaderNextError, }, - - /// Network or IO operation failed. - #[snafu(display("A network or IO operation failed"))] - Io { - #[snafu(source(from(IoCases, Box::new)))] - source: Box, + #[snafu(transparent)] + Decode { + source: DecodeError, }, - /// Our download request is invalid. - #[snafu(display("Our download request is invalid"))] - BadRequest { - #[snafu(source(from(BadRequestCases, Box::new)))] - source: Box, + #[snafu(transparent)] + IrpcSend { + source: irpc::channel::SendError, + }, + #[snafu(transparent)] + AtClosingNext { + source: AtClosingNextError, }, - /// Operation failed on the local node. - #[snafu(display("Operation failed on the local node"))] LocalFailure { - #[snafu(source(from(LocalFailureCases, Box::new)))] - source: Box, + source: anyhow::Error, + }, + BadRequest { + source: anyhow::Error, }, } -pub type GetResult = std::result::Result; - -impl From for GetError { - fn from(value: irpc::channel::SendError) -> Self { - LocalFailureSnafu.into_error(value.into()) - } -} - -impl From> for GetError { - fn from(value: tokio::sync::mpsc::error::SendError) -> Self { - LocalFailureSnafu.into_error(value.into()) - } -} - -impl From for GetError { - fn from(value: endpoint::ConnectionError) -> Self { - // explicit match just to be sure we are taking everything into account - use endpoint::ConnectionError; - match value { - e @ ConnectionError::VersionMismatch => { - // > The peer doesn't implement any supported version - // unsupported version is likely a long time error, so this peer is not usable - NoncompliantNodeSnafu.into_error(e.into()) - } - e @ ConnectionError::TransportError(_) => { - // > The peer violated the QUIC specification as understood by this implementation - // bad peer we don't want to keep around - NoncompliantNodeSnafu.into_error(e.into()) - } - e @ ConnectionError::ConnectionClosed(_) => { - // > The peer's QUIC stack aborted the connection automatically - // peer might be disconnecting or otherwise unavailable, drop it - IoSnafu.into_error(e.into()) - } - e @ ConnectionError::ApplicationClosed(_) => { - // > The peer closed the connection - // peer might be disconnecting or otherwise unavailable, drop it - IoSnafu.into_error(e.into()) - } - e @ ConnectionError::Reset => { - // > The peer is unable to continue processing this connection, usually due to having restarted - RemoteResetSnafu.into_error(e.into()) - } - e @ ConnectionError::TimedOut => { - // > Communication with the peer has lapsed for longer than the negotiated idle timeout - IoSnafu.into_error(e.into()) - } - e @ ConnectionError::LocallyClosed => { - // > The local application closed the connection - // TODO(@divma): don't see how this is reachable but let's just not use the peer - IoSnafu.into_error(e.into()) - } - e @ ConnectionError::CidsExhausted => { - // > The connection could not be created because not enough of the CID space - // > is available - IoSnafu.into_error(e.into()) - } - } - } -} - -impl From for GetError { - fn from(value: endpoint::ReadError) -> Self { - use endpoint::ReadError; - match value { - e @ ReadError::Reset(_) => RemoteResetSnafu.into_error(e.into()), - ReadError::ConnectionLost(conn_error) => conn_error.into(), - ReadError::ClosedStream - | ReadError::IllegalOrderedRead - | ReadError::ZeroRttRejected => { - // all these errors indicate the peer is not usable at this moment - IoSnafu.into_error(value.into()) - } +impl GetError { + pub fn iroh_error_code(&self) -> Option { + if let Some(ReadError::Reset(code)) = self + .remote_read() + .and_then(|source| source.get_ref()) + .and_then(|e| e.downcast_ref::()) + { + Some(*code) + } else if let Some(WriteError::Stopped(code)) = self + .remote_write() + .and_then(|source| source.get_ref()) + .and_then(|e| e.downcast_ref::()) + { + Some(*code) + } else if let Some(ConnectionError::ApplicationClosed(ac)) = self + .open() + .and_then(|source| source.get_ref()) + .and_then(|e| e.downcast_ref::()) + { + Some(ac.error_code) + } else { + None } } -} -impl From for GetError { - fn from(value: ClosedStream) -> Self { - IoSnafu.into_error(value.into()) - } -} -impl From for GetError { - fn from(value: quinn::WriteError) -> Self { - use quinn::WriteError; - match value { - e @ WriteError::Stopped(_) => RemoteResetSnafu.into_error(e.into()), - WriteError::ConnectionLost(conn_error) => conn_error.into(), - WriteError::ClosedStream | WriteError::ZeroRttRejected => { - // all these errors indicate the peer is not usable at this moment - IoSnafu.into_error(value.into()) - } + pub fn remote_write(&self) -> Option<&io::Error> { + match self { + Self::ConnectedNext { + source: ConnectedNextError::Write { source, .. }, + .. + } => Some(source), + _ => None, } } -} -impl From for GetError { - fn from(value: crate::get::fsm::ConnectedNextError) -> Self { - use crate::get::fsm::ConnectedNextError::*; - match value { - e @ PostcardSer { .. } => { - // serialization errors indicate something wrong with the request itself - BadRequestSnafu.into_error(e.into()) - } - e @ RequestTooBig { .. } => { - // request will never be sent, drop it - BadRequestSnafu.into_error(e.into()) - } - Write { source, .. } => source.into(), - Closed { source, .. } => source.into(), - e @ Io { .. } => { - // io errors are likely recoverable - IoSnafu.into_error(e.into()) - } + pub fn open(&self) -> Option<&io::Error> { + match self { + Self::InitialNext { + source: InitialNextError::Open { source, .. }, + .. + } => Some(source), + _ => None, } } -} -impl From for GetError { - fn from(value: crate::get::fsm::AtBlobHeaderNextError) -> Self { - use crate::get::fsm::AtBlobHeaderNextError::*; - match value { - e @ NotFound { .. } => { - // > This indicates that the provider does not have the requested data. - // peer might have the data later, simply retry it - NotFoundSnafu.into_error(e.into()) - } - EndpointRead { source, .. } => source.into(), - e @ Io { .. } => { - // io errors are likely recoverable - IoSnafu.into_error(e.into()) - } + pub fn remote_read(&self) -> Option<&io::Error> { + match self { + Self::AtBlobHeaderNext { + source: AtBlobHeaderNextError::Read { source, .. }, + .. + } => Some(source), + Self::Decode { + source: DecodeError::Read { source, .. }, + .. + } => Some(source), + Self::AtClosingNext { + source: AtClosingNextError::Read { source, .. }, + .. + } => Some(source), + _ => None, } } -} - -impl From for GetError { - fn from(value: crate::get::fsm::DecodeError) -> Self { - use crate::get::fsm::DecodeError::*; - match value { - e @ ChunkNotFound { .. } => NotFoundSnafu.into_error(e.into()), - e @ ParentNotFound { .. } => NotFoundSnafu.into_error(e.into()), - e @ LeafNotFound { .. } => NotFoundSnafu.into_error(e.into()), - e @ ParentHashMismatch { .. } => { - // TODO(@divma): did the peer sent wrong data? is it corrupted? did we sent a wrong - // request? - NoncompliantNodeSnafu.into_error(e.into()) - } - e @ LeafHashMismatch { .. } => { - // TODO(@divma): did the peer sent wrong data? is it corrupted? did we sent a wrong - // request? - NoncompliantNodeSnafu.into_error(e.into()) - } - Read { source, .. } => source.into(), - DecodeIo { source, .. } => source.into(), + pub fn local_write(&self) -> Option<&io::Error> { + match self { + Self::Decode { + source: DecodeError::Write { source, .. }, + .. + } => Some(source), + _ => None, } } } -impl From for GetError { - fn from(value: std::io::Error) -> Self { - // generally consider io errors recoverable - // we might want to revisit this at some point - IoSnafu.into_error(value.into()) - } -} +pub type GetResult = std::result::Result; diff --git a/src/get/request.rs b/src/get/request.rs index 98563057e..c1dc034d3 100644 --- a/src/get/request.rs +++ b/src/get/request.rs @@ -25,7 +25,7 @@ use tokio::sync::mpsc; use super::{fsm, GetError, GetResult, Stats}; use crate::{ - get::error::{BadRequestSnafu, LocalFailureSnafu}, + get::get_error::{BadRequestSnafu, LocalFailureSnafu}, hashseq::HashSeq, protocol::{ChunkRangesExt, ChunkRangesSeq, GetRequest}, Hash, HashAndFormat, @@ -58,7 +58,7 @@ impl GetBlobResult { let mut parts = Vec::new(); let stats = loop { let Some(item) = self.next().await else { - return Err(LocalFailureSnafu.into_error(anyhow::anyhow!("unexpected end").into())); + return Err(LocalFailureSnafu.into_error(anyhow::anyhow!("unexpected end"))); }; match item { GetBlobItem::Item(item) => { @@ -238,11 +238,11 @@ pub async fn get_hash_seq_and_sizes( let (at_blob_content, size) = at_start_root.next().await?; // check the size to avoid parsing a maliciously large hash seq if size > max_size { - return Err(BadRequestSnafu.into_error(anyhow::anyhow!("size too large").into())); + return Err(BadRequestSnafu.into_error(anyhow::anyhow!("size too large"))); } let (mut curr, hash_seq) = at_blob_content.concatenate_into_vec().await?; - let hash_seq = HashSeq::try_from(Bytes::from(hash_seq)) - .map_err(|e| BadRequestSnafu.into_error(e.into()))?; + let hash_seq = + HashSeq::try_from(Bytes::from(hash_seq)).map_err(|e| BadRequestSnafu.into_error(e))?; let mut sizes = Vec::with_capacity(hash_seq.len()); let closing = loop { match curr.next() { diff --git a/src/net_protocol.rs b/src/net_protocol.rs index 3e7d9582e..47cda5344 100644 --- a/src/net_protocol.rs +++ b/src/net_protocol.rs @@ -36,22 +36,16 @@ //! # } //! ``` -use std::{fmt::Debug, future::Future, ops::Deref, sync::Arc}; +use std::{fmt::Debug, ops::Deref, sync::Arc}; use iroh::{ endpoint::Connection, protocol::{AcceptError, ProtocolHandler}, Endpoint, Watcher, }; -use tokio::sync::mpsc; use tracing::error; -use crate::{ - api::Store, - provider::{Event, EventSender}, - ticket::BlobTicket, - HashAndFormat, -}; +use crate::{api::Store, provider::events::EventSender, ticket::BlobTicket, HashAndFormat}; #[derive(Debug)] pub(crate) struct BlobsInner { @@ -75,12 +69,12 @@ impl Deref for BlobsProtocol { } impl BlobsProtocol { - pub fn new(store: &Store, endpoint: Endpoint, events: Option>) -> Self { + pub fn new(store: &Store, endpoint: Endpoint, events: Option) -> Self { Self { inner: Arc::new(BlobsInner { store: store.clone(), endpoint, - events: EventSender::new(events), + events: events.unwrap_or(EventSender::DEFAULT), }), } } @@ -106,25 +100,16 @@ impl BlobsProtocol { } impl ProtocolHandler for BlobsProtocol { - fn accept( - &self, - conn: Connection, - ) -> impl Future> + Send { + async fn accept(&self, conn: Connection) -> std::result::Result<(), AcceptError> { let store = self.store().clone(); let events = self.inner.events.clone(); - - Box::pin(async move { - crate::provider::handle_connection(conn, store, events).await; - Ok(()) - }) + crate::provider::handle_connection(conn, store, events).await; + Ok(()) } - fn shutdown(&self) -> impl Future + Send { - let store = self.store().clone(); - Box::pin(async move { - if let Err(cause) = store.shutdown().await { - error!("error shutting down store: {:?}", cause); - } - }) + async fn shutdown(&self) { + if let Err(cause) = self.store().shutdown().await { + error!("error shutting down store: {:?}", cause); + } } } diff --git a/src/protocol.rs b/src/protocol.rs index 74e0f986d..8aed6539a 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -382,7 +382,7 @@ use bao_tree::{io::round_up_to_chunks, ChunkNum}; use builder::GetRequestBuilder; use derive_more::From; use iroh::endpoint::VarInt; -use irpc::util::AsyncReadVarintExt; +use iroh_io::AsyncStreamReader; use postcard::experimental::max_size::MaxSize; use range_collections::{range_set::RangeSetEntry, RangeSet2}; use serde::{Deserialize, Serialize}; @@ -390,13 +390,19 @@ mod range_spec; pub use bao_tree::ChunkRanges; pub use range_spec::{ChunkRangesSeq, NonEmptyRequestRangeSpecIter, RangeSpec}; use snafu::{GenerateImplicitData, Snafu}; -use tokio::io::AsyncReadExt; -use crate::{api::blobs::Bitfield, provider::CountingReader, BlobFormat, Hash, HashAndFormat}; +use crate::{api::blobs::Bitfield, provider::RecvStreamExt, BlobFormat, Hash, HashAndFormat}; /// Maximum message size is limited to 100MiB for now. pub const MAX_MESSAGE_SIZE: usize = 1024 * 1024; +/// Error code for a permission error +pub const ERR_PERMISSION: VarInt = VarInt::from_u32(1u32); +/// Error code for when a request is aborted due to a rate limit +pub const ERR_LIMIT: VarInt = VarInt::from_u32(2u32); +/// Error code for when a request is aborted due to internal error +pub const ERR_INTERNAL: VarInt = VarInt::from_u32(3u32); + /// The ALPN used with quic for the iroh blobs protocol. pub const ALPN: &[u8] = b"/iroh-bytes/4"; @@ -441,9 +447,7 @@ pub enum RequestType { } impl Request { - pub async fn read_async( - reader: &mut CountingReader<&mut iroh::endpoint::RecvStream>, - ) -> io::Result { + pub async fn read_async(reader: &mut R) -> io::Result<(Self, usize)> { let request_type = reader.read_u8().await?; let request_type: RequestType = postcard::from_bytes(std::slice::from_ref(&request_type)) .map_err(|_| { @@ -453,22 +457,31 @@ impl Request { ) })?; Ok(match request_type { - RequestType::Get => reader - .read_to_end_as::(MAX_MESSAGE_SIZE) - .await? - .into(), - RequestType::GetMany => reader - .read_to_end_as::(MAX_MESSAGE_SIZE) - .await? - .into(), - RequestType::Observe => reader - .read_to_end_as::(MAX_MESSAGE_SIZE) - .await? - .into(), - RequestType::Push => reader - .read_length_prefixed::(MAX_MESSAGE_SIZE) - .await? - .into(), + RequestType::Get => { + let (r, size) = reader + .read_to_end_as::(MAX_MESSAGE_SIZE) + .await?; + (r.into(), size) + } + RequestType::GetMany => { + let (r, size) = reader + .read_to_end_as::(MAX_MESSAGE_SIZE) + .await?; + (r.into(), size) + } + RequestType::Observe => { + let (r, size) = reader + .read_to_end_as::(MAX_MESSAGE_SIZE) + .await?; + (r.into(), size) + } + RequestType::Push => { + let r = reader + .read_length_prefixed::(MAX_MESSAGE_SIZE) + .await?; + let size = postcard::experimental::serialized_size(&r).unwrap(); + (r.into(), size) + } _ => { return Err(io::Error::new( io::ErrorKind::InvalidData, diff --git a/src/provider.rs b/src/provider.rs index 61af8f6e1..a79e0ad8f 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -5,131 +5,46 @@ //! handler with an [`iroh::Endpoint`](iroh::protocol::Router). use std::{ fmt::Debug, + future::Future, io, - ops::{Deref, DerefMut}, - pin::Pin, - task::Poll, - time::Duration, + time::{Duration, Instant}, }; -use anyhow::{Context, Result}; +use anyhow::Result; use bao_tree::ChunkRanges; -use iroh::{ - endpoint::{self, RecvStream, SendStream}, - NodeId, -}; -use irpc::channel::oneshot; +use iroh::endpoint; +use iroh_io::{AsyncStreamReader, AsyncStreamWriter}; use n0_future::StreamExt; -use serde::de::DeserializeOwned; -use tokio::{io::AsyncRead, select, sync::mpsc}; -use tracing::{debug, debug_span, error, warn, Instrument}; +use quinn::{ConnectionError, VarInt}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use snafu::Snafu; +use tokio::select; +use tracing::{debug, debug_span, warn, Instrument}; use crate::{ - api::{self, blobs::Bitfield, Store}, + api::{ + blobs::{Bitfield, WriteProgress}, + ExportBaoError, ExportBaoResult, RequestError, Store, + }, + get::{IrohStreamReader, IrohStreamWriter}, hashseq::HashSeq, protocol::{ - ChunkRangesSeq, GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, - Request, + GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, Request, ERR_INTERNAL, + }, + provider::events::{ + ClientConnected, ClientResult, ConnectionClosed, HasErrorCode, ProgressError, + RequestTracker, }, Hash, }; +pub mod events; +use events::EventSender; -/// Provider progress events, to keep track of what the provider is doing. -/// -/// ClientConnected -> -/// (GetRequestReceived -> (TransferStarted -> TransferProgress*n)*n -> (TransferCompleted | TransferAborted))*n -> -/// ConnectionClosed -#[derive(Debug)] -pub enum Event { - /// A new client connected to the provider. - ClientConnected { - connection_id: u64, - node_id: NodeId, - permitted: oneshot::Sender, - }, - /// Connection closed. - ConnectionClosed { connection_id: u64 }, - /// A new get request was received from the provider. - GetRequestReceived { - /// The connection id. Multiple requests can be sent over the same connection. - connection_id: u64, - /// The request id. There is a new id for each request. - request_id: u64, - /// The root hash of the request. - hash: Hash, - /// The exact query ranges of the request. - ranges: ChunkRangesSeq, - }, - /// A new get request was received from the provider. - GetManyRequestReceived { - /// The connection id. Multiple requests can be sent over the same connection. - connection_id: u64, - /// The request id. There is a new id for each request. - request_id: u64, - /// The root hash of the request. - hashes: Vec, - /// The exact query ranges of the request. - ranges: ChunkRangesSeq, - }, - /// A new get request was received from the provider. - PushRequestReceived { - /// The connection id. Multiple requests can be sent over the same connection. - connection_id: u64, - /// The request id. There is a new id for each request. - request_id: u64, - /// The root hash of the request. - hash: Hash, - /// The exact query ranges of the request. - ranges: ChunkRangesSeq, - /// Complete this to permit the request. - permitted: oneshot::Sender, - }, - /// Transfer for the nth blob started. - TransferStarted { - /// The connection id. Multiple requests can be sent over the same connection. - connection_id: u64, - /// The request id. There is a new id for each request. - request_id: u64, - /// The index of the blob in the request. 0 for the first blob or for raw blob requests. - index: u64, - /// The hash of the blob. This is the hash of the request for the first blob, the child hash (index-1) for subsequent blobs. - hash: Hash, - /// The size of the blob. This is the full size of the blob, not the size we are sending. - size: u64, - }, - /// Progress of the transfer. - TransferProgress { - /// The connection id. Multiple requests can be sent over the same connection. - connection_id: u64, - /// The request id. There is a new id for each request. - request_id: u64, - /// The index of the blob in the request. 0 for the first blob or for raw blob requests. - index: u64, - /// The end offset of the chunk that was sent. - end_offset: u64, - }, - /// Entire transfer completed. - TransferCompleted { - /// The connection id. Multiple requests can be sent over the same connection. - connection_id: u64, - /// The request id. There is a new id for each request. - request_id: u64, - /// Statistics about the transfer. - stats: Box, - }, - /// Entire transfer aborted - TransferAborted { - /// The connection id. Multiple requests can be sent over the same connection. - connection_id: u64, - /// The request id. There is a new id for each request. - request_id: u64, - /// Statistics about the part of the transfer that was aborted. - stats: Option>, - }, -} +type DefaultWriter = IrohStreamWriter; +type DefaultReader = IrohStreamReader; /// Statistics about a successful or failed transfer. -#[derive(Debug)] +#[derive(Debug, Serialize, Deserialize)] pub struct TransferStats { /// The number of bytes sent that are part of the payload. pub payload_bytes_sent: u64, @@ -139,191 +54,245 @@ pub struct TransferStats { pub other_bytes_sent: u64, /// The number of bytes read from the stream. /// - /// This is the size of the request. - pub bytes_read: u64, + /// In most cases this is just the request, for push requests this is + /// request, size header and hash pairs. + pub other_bytes_read: u64, /// Total duration from reading the request to transfer completed. pub duration: Duration, } -/// Read the request from the getter. -/// -/// Will fail if there is an error while reading, or if no valid request is sent. -/// -/// This will read exactly the number of bytes needed for the request, and -/// leave the rest of the stream for the caller to read. -/// -/// It is up to the caller do decide if there should be more data. -pub async fn read_request(reader: &mut ProgressReader) -> Result { - let mut counting = CountingReader::new(&mut reader.inner); - let res = Request::read_async(&mut counting).await?; - reader.bytes_read += counting.read(); - Ok(res) +/// A pair of [`SendStream`] and [`RecvStream`] with additional context data. +#[derive(Debug)] +pub struct StreamPair { + t0: Instant, + connection_id: u64, + request_id: u64, + reader: R, + writer: W, + other_bytes_read: u64, + events: EventSender, } -#[derive(Debug)] -pub struct StreamContext { - /// The connection ID from the connection - pub connection_id: u64, - /// The request ID from the recv stream - pub request_id: u64, - /// The number of bytes written that are part of the payload - pub payload_bytes_sent: u64, - /// The number of bytes written that are not part of the payload - pub other_bytes_sent: u64, - /// The number of bytes read from the stream - pub bytes_read: u64, - /// The progress sender to send events to - pub progress: EventSender, +impl StreamPair { + pub async fn accept( + conn: &endpoint::Connection, + events: EventSender, + ) -> Result { + let (writer, reader) = conn.accept_bi().await?; + Ok(Self::new( + conn.stable_id() as u64, + reader.id().into(), + IrohStreamReader(reader), + IrohStreamWriter(writer), + events, + )) + } } -/// Wrapper for a [`quinn::SendStream`] with additional per request information. -#[derive(Debug)] -pub struct ProgressWriter { - /// The quinn::SendStream to write to - pub inner: SendStream, - pub(crate) context: StreamContext, +impl StreamPair { + pub fn new( + connection_id: u64, + request_id: u64, + reader: R, + writer: W, + events: EventSender, + ) -> Self { + Self { + t0: Instant::now(), + connection_id, + request_id, + reader, + writer, + other_bytes_read: 0, + events, + } + } + + /// Read the request. + /// + /// Will fail if there is an error while reading, or if no valid request is sent. + /// + /// This will read exactly the number of bytes needed for the request, and + /// leave the rest of the stream for the caller to read. + /// + /// It is up to the caller do decide if there should be more data. + pub async fn read_request(&mut self) -> Result { + let (res, size) = Request::read_async(&mut self.reader).await?; + self.other_bytes_read += size as u64; + Ok(res) + } + + /// We are done with reading. Return a ProgressWriter that contains the read stats and connection id + pub async fn into_writer( + mut self, + tracker: RequestTracker, + ) -> Result, io::Error> { + self.reader.expect_eof().await?; + drop(self.reader); + Ok(ProgressWriter::new( + self.writer, + WriterContext { + t0: self.t0, + other_bytes_read: self.other_bytes_read, + payload_bytes_written: 0, + other_bytes_written: 0, + tracker, + }, + )) + } + + pub async fn into_reader( + mut self, + tracker: RequestTracker, + ) -> Result, io::Error> { + self.writer.sync().await?; + drop(self.writer); + Ok(ProgressReader { + inner: self.reader, + context: ReaderContext { + t0: self.t0, + other_bytes_read: self.other_bytes_read, + tracker, + }, + }) + } + + pub async fn get_request( + &self, + f: impl FnOnce() -> GetRequest, + ) -> Result { + self.events + .request(f, self.connection_id, self.request_id) + .await + } + + pub async fn get_many_request( + &self, + f: impl FnOnce() -> GetManyRequest, + ) -> Result { + self.events + .request(f, self.connection_id, self.request_id) + .await + } + + pub async fn push_request( + &self, + f: impl FnOnce() -> PushRequest, + ) -> Result { + self.events + .request(f, self.connection_id, self.request_id) + .await + } + + pub async fn observe_request( + &self, + f: impl FnOnce() -> ObserveRequest, + ) -> Result { + self.events + .request(f, self.connection_id, self.request_id) + .await + } + + pub fn stats(&self) -> TransferStats { + TransferStats { + payload_bytes_sent: 0, + other_bytes_sent: 0, + other_bytes_read: self.other_bytes_read, + duration: self.t0.elapsed(), + } + } } -impl Deref for ProgressWriter { - type Target = StreamContext; +#[derive(Debug)] +struct ReaderContext { + /// The start time of the transfer + t0: Instant, + /// The number of bytes read from the stream + other_bytes_read: u64, + /// Progress tracking for the request + tracker: RequestTracker, +} - fn deref(&self) -> &Self::Target { - &self.context +impl ReaderContext { + fn stats(&self) -> TransferStats { + TransferStats { + payload_bytes_sent: 0, + other_bytes_sent: 0, + other_bytes_read: self.other_bytes_read, + duration: self.t0.elapsed(), + } } } -impl DerefMut for ProgressWriter { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.context +#[derive(Debug)] +pub(crate) struct WriterContext { + /// The start time of the transfer + t0: Instant, + /// The number of bytes read from the stream + other_bytes_read: u64, + /// The number of payload bytes written to the stream + payload_bytes_written: u64, + /// The number of bytes written that are not part of the payload + other_bytes_written: u64, + /// Way to report progress + tracker: RequestTracker, +} + +impl WriterContext { + fn stats(&self) -> TransferStats { + TransferStats { + payload_bytes_sent: self.payload_bytes_written, + other_bytes_sent: self.other_bytes_written, + other_bytes_read: self.other_bytes_read, + duration: self.t0.elapsed(), + } } } -impl StreamContext { - /// Increase the write count due to a non-payload write. - pub fn log_other_write(&mut self, len: usize) { - self.other_bytes_sent += len as u64; +impl WriteProgress for WriterContext { + async fn notify_payload_write(&mut self, _index: u64, offset: u64, len: usize) -> ClientResult { + let len = len as u64; + let end_offset = offset + len; + self.payload_bytes_written += len; + self.tracker.transfer_progress(len, end_offset).await } - pub async fn send_transfer_completed(&mut self) { - self.progress - .send(|| Event::TransferCompleted { - connection_id: self.connection_id, - request_id: self.request_id, - stats: Box::new(TransferStats { - payload_bytes_sent: self.payload_bytes_sent, - other_bytes_sent: self.other_bytes_sent, - bytes_read: self.bytes_read, - duration: Duration::ZERO, - }), - }) - .await; - } - - pub async fn send_transfer_aborted(&mut self) { - self.progress - .send(|| Event::TransferAborted { - connection_id: self.connection_id, - request_id: self.request_id, - stats: Some(Box::new(TransferStats { - payload_bytes_sent: self.payload_bytes_sent, - other_bytes_sent: self.other_bytes_sent, - bytes_read: self.bytes_read, - duration: Duration::ZERO, - })), - }) - .await; + fn log_other_write(&mut self, len: usize) { + self.other_bytes_written += len as u64; } - /// Increase the write count due to a payload write, and notify the progress sender. - /// - /// `index` is the index of the blob in the request. - /// `offset` is the offset in the blob where the write started. - /// `len` is the length of the write. - pub fn notify_payload_write(&mut self, index: u64, offset: u64, len: usize) { - self.payload_bytes_sent += len as u64; - self.progress.try_send(|| Event::TransferProgress { - connection_id: self.connection_id, - request_id: self.request_id, - index, - end_offset: offset + len as u64, - }); - } - - /// Send a get request received event. - /// - /// This sends all the required information to make sense of subsequent events such as - /// [`Event::TransferStarted`] and [`Event::TransferProgress`]. - pub async fn send_get_request_received(&self, hash: &Hash, ranges: &ChunkRangesSeq) { - self.progress - .send(|| Event::GetRequestReceived { - connection_id: self.connection_id, - request_id: self.request_id, - hash: *hash, - ranges: ranges.clone(), - }) - .await; + async fn send_transfer_started(&mut self, index: u64, hash: &Hash, size: u64) { + self.tracker.transfer_started(index, hash, size).await.ok(); } +} - /// Send a get request received event. - /// - /// This sends all the required information to make sense of subsequent events such as - /// [`Event::TransferStarted`] and [`Event::TransferProgress`]. - pub async fn send_get_many_request_received(&self, hashes: &[Hash], ranges: &ChunkRangesSeq) { - self.progress - .send(|| Event::GetManyRequestReceived { - connection_id: self.connection_id, - request_id: self.request_id, - hashes: hashes.to_vec(), - ranges: ranges.clone(), - }) - .await; +/// Wrapper for a [`quinn::SendStream`] with additional per request information. +#[derive(Debug)] +pub struct ProgressWriter { + /// The quinn::SendStream to write to + pub inner: W, + pub(crate) context: WriterContext, +} + +impl ProgressWriter { + fn new(inner: W, context: WriterContext) -> Self { + Self { inner, context } } - /// Authorize a push request. - /// - /// This will send a request to the event sender, and wait for a response if a - /// progress sender is enabled. If not, it will always fail. - /// - /// We want to make accepting push requests very explicit, since this allows - /// remote nodes to add arbitrary data to our store. - #[must_use = "permit should be checked by the caller"] - pub async fn authorize_push_request(&self, hash: &Hash, ranges: &ChunkRangesSeq) -> bool { - let mut wait_for_permit = None; - // send the request, including the permit channel - self.progress - .send(|| { - let (tx, rx) = oneshot::channel(); - wait_for_permit = Some(rx); - Event::PushRequestReceived { - connection_id: self.connection_id, - request_id: self.request_id, - hash: *hash, - ranges: ranges.clone(), - permitted: tx, - } - }) - .await; - // wait for the permit, if necessary - if let Some(wait_for_permit) = wait_for_permit { - // if somebody does not handle the request, they will drop the channel, - // and this will fail immediately. - wait_for_permit.await.unwrap_or(false) - } else { - false - } + async fn transfer_aborted(&self) { + self.context + .tracker + .transfer_aborted(|| Box::new(self.context.stats())) + .await + .ok(); } - /// Send a transfer started event. - pub async fn send_transfer_started(&self, index: u64, hash: &Hash, size: u64) { - self.progress - .send(|| Event::TransferStarted { - connection_id: self.connection_id, - request_id: self.request_id, - index, - hash: *hash, - size, - }) - .await; + async fn transfer_completed(&self) { + self.context + .tracker + .transfer_completed(|| Box::new(self.context.stats())) + .await + .ok(); } } @@ -340,122 +309,147 @@ pub async fn handle_connection( warn!("failed to get node id"); return; }; - if !progress - .authorize_client_connection(connection_id, node_id) + if let Err(cause) = progress + .client_connected(|| ClientConnected { + connection_id, + node_id, + }) .await { - debug!("client not authorized to connect"); + connection.close(cause.code(), cause.reason()); + debug!("closing connection: {cause}"); return; } - while let Ok((writer, reader)) = connection.accept_bi().await { - // The stream ID index is used to identify this request. Requests only arrive in - // bi-directional RecvStreams initiated by the client, so this uniquely identifies them. - let request_id = reader.id().index(); - let span = debug_span!("stream", stream_id = %request_id); + while let Ok(context) = StreamPair::accept(&connection, progress.clone()).await { + let span = debug_span!("stream", stream_id = %context.request_id); let store = store.clone(); - let mut writer = ProgressWriter { - inner: writer, - context: StreamContext { - connection_id, - request_id, - payload_bytes_sent: 0, - other_bytes_sent: 0, - bytes_read: 0, - progress: progress.clone(), - }, - }; - tokio::spawn( - async move { - match handle_stream(store, reader, &mut writer).await { - Ok(()) => { - writer.send_transfer_completed().await; - } - Err(err) => { - warn!("error: {err:#?}",); - writer.send_transfer_aborted().await; - } - } - } - .instrument(span), - ); + tokio::spawn(handle_stream(context, store).instrument(span)); } progress - .send(Event::ConnectionClosed { connection_id }) - .await; + .connection_closed(|| ConnectionClosed { connection_id }) + .await + .ok(); } .instrument(span) .await } -async fn handle_stream( - store: Store, - reader: RecvStream, - writer: &mut ProgressWriter, -) -> Result<()> { - // 1. Decode the request. - debug!("reading request"); - let mut reader = ProgressReader { - inner: reader, - context: StreamContext { - connection_id: writer.connection_id, - request_id: writer.request_id, - payload_bytes_sent: 0, - other_bytes_sent: 0, - bytes_read: 0, - progress: writer.progress.clone(), - }, - }; - let request = match read_request(&mut reader).await { - Ok(request) => request, +/// Describes how to handle errors for a stream. +pub trait ErrorHandler { + type W: AsyncStreamWriter; + type R: AsyncStreamReader; + fn stop(reader: &mut Self::R, code: VarInt) -> impl Future; + fn reset(writer: &mut Self::W, code: VarInt) -> impl Future; +} + +async fn handle_read_request_result( + pair: &mut StreamPair, + r: Result, +) -> Result { + match r { + Ok(x) => Ok(x), Err(e) => { - // todo: increase invalid requests metric counter - return Err(e); + H::reset(&mut pair.writer, e.code()).await; + Err(e) } - }; - - match request { - Request::Get(request) => { - // we expect no more bytes after the request, so if there are more bytes, it is an invalid request. - reader.inner.read_to_end(0).await?; - // move the context so we don't lose the bytes read - writer.context = reader.context; - handle_get(store, request, writer).await + } +} +async fn handle_write_result( + writer: &mut ProgressWriter, + r: Result, +) -> Result { + match r { + Ok(x) => { + writer.transfer_completed().await; + Ok(x) } - Request::GetMany(request) => { - // we expect no more bytes after the request, so if there are more bytes, it is an invalid request. - reader.inner.read_to_end(0).await?; - // move the context so we don't lose the bytes read - writer.context = reader.context; - handle_get_many(store, request, writer).await + Err(e) => { + H::reset(&mut writer.inner, e.code()).await; + writer.transfer_aborted().await; + Err(e) } - Request::Observe(request) => { - // we expect no more bytes after the request, so if there are more bytes, it is an invalid request. - reader.inner.read_to_end(0).await?; - handle_observe(store, request, writer).await + } +} +async fn handle_read_result( + reader: &mut ProgressReader, + r: Result, +) -> Result { + match r { + Ok(x) => { + reader.transfer_completed().await; + Ok(x) + } + Err(e) => { + H::stop(&mut reader.inner, e.code()).await; + reader.transfer_aborted().await; + Err(e) } - Request::Push(request) => { - writer.inner.finish()?; - handle_push(store, request, reader).await + } +} +struct IrohErrorHandler; + +impl ErrorHandler for IrohErrorHandler { + type W = DefaultWriter; + type R = DefaultReader; + + async fn stop(reader: &mut Self::R, code: VarInt) { + reader.0.stop(code).ok(); + } + async fn reset(writer: &mut Self::W, code: VarInt) { + writer.0.reset(code).ok(); + } +} + +pub async fn handle_stream(mut pair: StreamPair, store: Store) -> anyhow::Result<()> { + // 1. Decode the request. + debug!("reading request"); + let request = pair.read_request().await?; + type H = IrohErrorHandler; + + match request { + Request::Get(request) => handle_get::(pair, store, request).await?, + Request::GetMany(request) => handle_get_many::(pair, store, request).await?, + Request::Observe(request) => handle_observe::(pair, store, request).await?, + Request::Push(request) => handle_push::(pair, store, request).await?, + _ => {} + } + Ok(()) +} + +#[derive(Debug, Snafu)] +#[snafu(module)] +pub enum HandleGetError { + #[snafu(transparent)] + ExportBao { + source: ExportBaoError, + }, + InvalidHashSeq, + InvalidOffset, +} + +impl HasErrorCode for HandleGetError { + fn code(&self) -> VarInt { + match self { + HandleGetError::ExportBao { + source: ExportBaoError::Progress { source, .. }, + } => source.code(), + HandleGetError::InvalidHashSeq => ERR_INTERNAL, + HandleGetError::InvalidOffset => ERR_INTERNAL, + _ => ERR_INTERNAL, } - _ => anyhow::bail!("unsupported request: {request:?}"), - // Request::Push(request) => handle_push(store, request, writer).await, } } /// Handle a single get request. /// /// Requires a database, the request, and a writer. -pub async fn handle_get( +async fn handle_get_impl( store: Store, request: GetRequest, - writer: &mut ProgressWriter, -) -> Result<()> { + writer: &mut ProgressWriter, +) -> Result<(), HandleGetError> { let hash = request.hash; debug!(%hash, "get received request"); - - writer - .send_get_request_received(&hash, &request.ranges) - .await; let mut hash_seq = None; for (offset, ranges) in request.ranges.iter_non_empty_infinite() { if offset == 0 { @@ -470,34 +464,67 @@ pub async fn handle_get( Some(b) => b, None => { let bytes = store.get_bytes(hash).await?; - let hs = HashSeq::try_from(bytes)?; + let hs = + HashSeq::try_from(bytes).map_err(|_| HandleGetError::InvalidHashSeq)?; hash_seq = Some(hs); hash_seq.as_ref().unwrap() } }; - let o = usize::try_from(offset - 1).context("offset too large")?; + let o = usize::try_from(offset - 1).map_err(|_| HandleGetError::InvalidOffset)?; let Some(hash) = hash_seq.get(o) else { break; }; send_blob(&store, offset, hash, ranges.clone(), writer).await?; } } + writer + .inner + .sync() + .await + .map_err(|e| HandleGetError::ExportBao { source: e.into() })?; Ok(()) } +pub async fn handle_get( + mut pair: StreamPair, + store: Store, + request: GetRequest, +) -> anyhow::Result<()> { + let res = pair.get_request(|| request.clone()).await; + let tracker = handle_read_request_result::(&mut pair, res).await?; + let mut writer = pair.into_writer(tracker).await?; + let res = handle_get_impl(store, request, &mut writer).await; + handle_write_result::(&mut writer, res).await?; + Ok(()) +} + +#[derive(Debug, Snafu)] +pub enum HandleGetManyError { + #[snafu(transparent)] + ExportBao { source: ExportBaoError }, +} + +impl HasErrorCode for HandleGetManyError { + fn code(&self) -> VarInt { + match self { + Self::ExportBao { + source: ExportBaoError::Progress { source, .. }, + } => source.code(), + _ => ERR_INTERNAL, + } + } +} + /// Handle a single get request. /// /// Requires a database, the request, and a writer. -pub async fn handle_get_many( +async fn handle_get_many_impl( store: Store, request: GetManyRequest, - writer: &mut ProgressWriter, -) -> Result<()> { + writer: &mut ProgressWriter, +) -> Result<(), HandleGetManyError> { debug!("get_many received request"); - writer - .send_get_many_request_received(&request.hashes, &request.ranges) - .await; let request_ranges = request.ranges.iter_infinite(); for (child, (hash, ranges)) in request.hashes.iter().zip(request_ranges).enumerate() { if !ranges.is_empty() { @@ -507,26 +534,61 @@ pub async fn handle_get_many( Ok(()) } +pub async fn handle_get_many( + mut pair: StreamPair, + store: Store, + request: GetManyRequest, +) -> anyhow::Result<()> { + let res = pair.get_many_request(|| request.clone()).await; + let tracker = handle_read_request_result::(&mut pair, res).await?; + let mut writer = pair.into_writer(tracker).await?; + let res = handle_get_many_impl(store, request, &mut writer).await; + handle_write_result::(&mut writer, res).await?; + Ok(()) +} + +#[derive(Debug, Snafu)] +pub enum HandlePushError { + #[snafu(transparent)] + ExportBao { + source: ExportBaoError, + }, + + InvalidHashSeq, + + #[snafu(transparent)] + Request { + source: RequestError, + }, +} + +impl HasErrorCode for HandlePushError { + fn code(&self) -> VarInt { + match self { + Self::ExportBao { + source: ExportBaoError::Progress { source, .. }, + } => source.code(), + _ => ERR_INTERNAL, + } + } +} + /// Handle a single push request. /// /// Requires a database, the request, and a reader. -pub async fn handle_push( +async fn handle_push_impl( store: Store, request: PushRequest, - mut reader: ProgressReader, -) -> Result<()> { + reader: &mut ProgressReader, +) -> Result<(), HandlePushError> { let hash = request.hash; debug!(%hash, "push received request"); - if !reader.authorize_push_request(&hash, &request.ranges).await { - debug!("push request not authorized"); - return Ok(()); - }; let mut request_ranges = request.ranges.iter_infinite(); let root_ranges = request_ranges.next().expect("infinite iterator"); if !root_ranges.is_empty() { // todo: send progress from import_bao_quinn or rename to import_bao_quinn_with_progress store - .import_bao_quinn(hash, root_ranges.clone(), &mut reader.inner) + .import_bao_reader(hash, root_ranges.clone(), &mut reader.inner) .await?; } if request.ranges.is_blob() { @@ -535,52 +597,85 @@ pub async fn handle_push( } // todo: we assume here that the hash sequence is complete. For some requests this might not be the case. We would need `LazyHashSeq` for that, but it is buggy as of now! let hash_seq = store.get_bytes(hash).await?; - let hash_seq = HashSeq::try_from(hash_seq)?; + let hash_seq = HashSeq::try_from(hash_seq).map_err(|_| HandlePushError::InvalidHashSeq)?; for (child_hash, child_ranges) in hash_seq.into_iter().zip(request_ranges) { if child_ranges.is_empty() { continue; } store - .import_bao_quinn(child_hash, child_ranges.clone(), &mut reader.inner) + .import_bao_reader(child_hash, child_ranges.clone(), &mut reader.inner) .await?; } Ok(()) } +pub async fn handle_push( + mut pair: StreamPair, + store: Store, + request: PushRequest, +) -> anyhow::Result<()> { + let res = pair.push_request(|| request.clone()).await; + let tracker = handle_read_request_result::(&mut pair, res).await?; + let mut reader = pair.into_reader(tracker).await?; + let res = handle_push_impl(store, request, &mut reader).await; + handle_read_result::(&mut reader, res).await?; + Ok(()) +} + /// Send a blob to the client. -pub(crate) async fn send_blob( +pub(crate) async fn send_blob( store: &Store, index: u64, hash: Hash, ranges: ChunkRanges, - writer: &mut ProgressWriter, -) -> api::Result<()> { - Ok(store + writer: &mut ProgressWriter, +) -> ExportBaoResult<()> { + store .export_bao(hash, ranges) - .write_quinn_with_progress(&mut writer.inner, &mut writer.context, &hash, index) - .await?) + .write_with_progress(&mut writer.inner, &mut writer.context, &hash, index) + .await +} + +#[derive(Debug, Snafu)] +pub enum HandleObserveError { + ObserveStreamClosed, + + #[snafu(transparent)] + RemoteClosed { + source: io::Error, + }, +} + +impl HasErrorCode for HandleObserveError { + fn code(&self) -> VarInt { + ERR_INTERNAL + } } /// Handle a single push request. /// /// Requires a database, the request, and a reader. -pub async fn handle_observe( +async fn handle_observe_impl( store: Store, request: ObserveRequest, writer: &mut ProgressWriter, -) -> Result<()> { - let mut stream = store.observe(request.hash).stream().await?; +) -> std::result::Result<(), HandleObserveError> { + let mut stream = store + .observe(request.hash) + .stream() + .await + .map_err(|_| HandleObserveError::ObserveStreamClosed)?; let mut old = stream .next() .await - .ok_or(anyhow::anyhow!("observe stream closed before first value"))?; + .ok_or(HandleObserveError::ObserveStreamClosed)?; // send the initial bitfield send_observe_item(writer, &old).await?; // send updates until the remote loses interest loop { select! { new = stream.next() => { - let new = new.context("observe stream closed")?; + let new = new.ok_or(HandleObserveError::ObserveStreamClosed)?; let diff = old.diff(&new); if diff.is_empty() { continue; @@ -588,7 +683,7 @@ pub async fn handle_observe( send_observe_item(writer, &diff).await?; old = new; } - _ = writer.inner.stopped() => { + _ = writer.inner.0.stopped() => { debug!("observer closed"); break; } @@ -597,166 +692,150 @@ pub async fn handle_observe( Ok(()) } -async fn send_observe_item(writer: &mut ProgressWriter, item: &Bitfield) -> Result<()> { +async fn send_observe_item(writer: &mut ProgressWriter, item: &Bitfield) -> io::Result<()> { use irpc::util::AsyncWriteVarintExt; let item = ObserveItem::from(item); - let len = writer.inner.write_length_prefixed(item).await?; - writer.log_other_write(len); + let len = writer.inner.0.write_length_prefixed(item).await?; + writer.context.log_other_write(len); Ok(()) } -/// Helper to lazyly create an [`Event`], in the case that the event creation -/// is expensive and we want to avoid it if the progress sender is disabled. -pub trait LazyEvent { - fn call(self) -> Event; +pub async fn handle_observe>( + mut pair: StreamPair, + store: Store, + request: ObserveRequest, +) -> anyhow::Result<()> { + let res = pair.observe_request(|| request.clone()).await; + let tracker = handle_read_request_result::(&mut pair, res).await?; + let mut writer = pair.into_writer(tracker).await?; + let res = handle_observe_impl(store, request, &mut writer).await; + handle_write_result::(&mut writer, res).await?; + Ok(()) } -impl LazyEvent for T -where - T: FnOnce() -> Event, -{ - fn call(self) -> Event { - self() - } +pub struct ProgressReader { + inner: R, + context: ReaderContext, } -impl LazyEvent for Event { - fn call(self) -> Event { - self +impl ProgressReader { + async fn transfer_aborted(&self) { + self.context + .tracker + .transfer_aborted(|| Box::new(self.context.stats())) + .await + .ok(); } -} -/// A sender for provider events. -#[derive(Debug, Clone)] -pub struct EventSender(EventSenderInner); - -#[derive(Debug, Clone)] -enum EventSenderInner { - Disabled, - Enabled(mpsc::Sender), + async fn transfer_completed(&self) { + self.context + .tracker + .transfer_completed(|| Box::new(self.context.stats())) + .await + .ok(); + } } -impl EventSender { - pub fn new(sender: Option>) -> Self { - match sender { - Some(sender) => Self(EventSenderInner::Enabled(sender)), - None => Self(EventSenderInner::Disabled), +pub(crate) trait RecvStreamExt: AsyncStreamReader { + async fn expect_eof(&mut self) -> io::Result<()> { + match self.read_u8().await { + Ok(_) => Err(io::Error::new( + io::ErrorKind::InvalidData, + "unexpected data", + )), + Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => Ok(()), + Err(e) => Err(e), } } - /// Send a client connected event, if the progress sender is enabled. - /// - /// This will permit the client to connect if the sender is disabled. - #[must_use = "permit should be checked by the caller"] - pub async fn authorize_client_connection(&self, connection_id: u64, node_id: NodeId) -> bool { - let mut wait_for_permit = None; - self.send(|| { - let (tx, rx) = oneshot::channel(); - wait_for_permit = Some(rx); - Event::ClientConnected { - connection_id, - node_id, - permitted: tx, - } - }) - .await; - if let Some(wait_for_permit) = wait_for_permit { - // if we have events configured, and they drop the channel, we consider that as a no! - // todo: this will be confusing and needs to be properly documented. - wait_for_permit.await.unwrap_or(false) - } else { - true - } + async fn read_u8(&mut self) -> io::Result { + let buf = self.read::<1>().await?; + Ok(buf[0]) } - /// Send an ephemeral event, if the progress sender is enabled. - /// - /// The event will only be created if the sender is enabled. - fn try_send(&self, event: impl LazyEvent) { - match &self.0 { - EventSenderInner::Enabled(sender) => { - let value = event.call(); - sender.try_send(value).ok(); - } - EventSenderInner::Disabled => {} - } + async fn read_to_end_as( + &mut self, + max_size: usize, + ) -> io::Result<(T, usize)> { + let data = self.read_bytes(max_size).await?; + self.expect_eof().await?; + let value = postcard::from_bytes(&data) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + Ok((value, data.len())) } - /// Send a mandatory event, if the progress sender is enabled. - /// - /// The event only be created if the sender is enabled. - async fn send(&self, event: impl LazyEvent) { - match &self.0 { - EventSenderInner::Enabled(sender) => { - let value = event.call(); - if let Err(err) = sender.send(value).await { - error!("failed to send progress event: {:?}", err); - } - } - EventSenderInner::Disabled => {} + async fn read_length_prefixed( + &mut self, + max_size: usize, + ) -> io::Result { + let Some(n) = self.read_varint_u64().await? else { + return Err(io::ErrorKind::UnexpectedEof.into()); + }; + if n > max_size as u64 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "length prefix too large", + )); } + let n = n as usize; + let data = self.read_bytes(n).await?; + let value = postcard::from_bytes(&data) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + Ok(value) } -} -pub struct ProgressReader { - inner: RecvStream, - context: StreamContext, -} + /// Reads a u64 varint from an AsyncRead source, using the Postcard/LEB128 format. + /// + /// In Postcard's varint format (LEB128): + /// - Each byte uses 7 bits for the value + /// - The MSB (most significant bit) of each byte indicates if there are more bytes (1) or not (0) + /// - Values are stored in little-endian order (least significant group first) + /// + /// Returns the decoded u64 value. + async fn read_varint_u64(&mut self) -> io::Result> { + let mut result: u64 = 0; + let mut shift: u32 = 0; + + loop { + // We can only shift up to 63 bits (for a u64) + if shift >= 64 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Varint is too large for u64", + )); + } -impl Deref for ProgressReader { - type Target = StreamContext; + // Read a single byte + let res = self.read_u8().await; + if shift == 0 { + if let Err(cause) = res { + if cause.kind() == io::ErrorKind::UnexpectedEof { + return Ok(None); + } else { + return Err(cause); + } + } + } - fn deref(&self) -> &Self::Target { - &self.context - } -} + let byte = res?; -impl DerefMut for ProgressReader { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.context - } -} + // Extract the 7 value bits (bits 0-6, excluding the MSB which is the continuation bit) + let value = (byte & 0x7F) as u64; -pub struct CountingReader { - pub inner: R, - pub read: u64, -} + // Add the bits to our result at the current shift position + result |= value << shift; -impl CountingReader { - pub fn new(inner: R) -> Self { - Self { inner, read: 0 } - } + // If the high bit is not set (0), this is the last byte + if byte & 0x80 == 0 { + break; + } - pub fn read(&self) -> u64 { - self.read - } -} + // Move to the next 7 bits + shift += 7; + } -impl CountingReader<&mut iroh::endpoint::RecvStream> { - pub async fn read_to_end_as(&mut self, max_size: usize) -> io::Result { - let data = self - .inner - .read_to_end(max_size) - .await - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - let value = postcard::from_bytes(&data) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - self.read += data.len() as u64; - Ok(value) + Ok(Some(result)) } } -impl AsyncRead for CountingReader { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - buf: &mut tokio::io::ReadBuf<'_>, - ) -> Poll> { - let this = self.get_mut(); - let result = Pin::new(&mut this.inner).poll_read(cx, buf); - if let Poll::Ready(Ok(())) = result { - this.read += buf.filled().len() as u64; - } - result - } -} +impl RecvStreamExt for R {} diff --git a/src/provider/events.rs b/src/provider/events.rs new file mode 100644 index 000000000..54511f92c --- /dev/null +++ b/src/provider/events.rs @@ -0,0 +1,693 @@ +use std::{fmt::Debug, io, ops::Deref}; + +use irpc::{ + channel::{mpsc, none::NoSender, oneshot}, + rpc_requests, Channels, WithChannels, +}; +use serde::{Deserialize, Serialize}; +use snafu::Snafu; + +use crate::{ + protocol::{ + GetManyRequest, GetRequest, ObserveRequest, PushRequest, ERR_INTERNAL, ERR_LIMIT, + ERR_PERMISSION, + }, + provider::{events::irpc_ext::IrpcClientExt, TransferStats}, + Hash, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +#[repr(u8)] +pub enum ConnectMode { + /// We don't get notification of connect events at all. + #[default] + None, + /// We get a notification for connect events. + Notify, + /// We get a request for connect events and can reject incoming connections. + Request, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +#[repr(u8)] +pub enum ObserveMode { + /// We don't get notification of connect events at all. + #[default] + None, + /// We get a notification for connect events. + Notify, + /// We get a request for connect events and can reject incoming connections. + Request, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +#[repr(u8)] +pub enum RequestMode { + /// We don't get request events at all. + #[default] + None, + /// We get a notification for each request, but no transfer events. + Notify, + /// We get a request for each request, and can reject incoming requests, but no transfer events. + Request, + /// We get a notification for each request as well as detailed transfer events. + NotifyLog, + /// We get a request for each request, and can reject incoming requests. + /// We also get detailed transfer events. + RequestLog, + /// This request type is completely disabled. All requests will be rejected. + Disabled, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +#[repr(u8)] +pub enum ThrottleMode { + /// We don't get these kinds of events at all + #[default] + None, + /// We call throttle to give the event handler a way to throttle requests + Throttle, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum AbortReason { + RateLimited, + Permission, +} + +#[derive(Debug, Snafu)] +pub enum ProgressError { + Limit, + Permission, + #[snafu(transparent)] + Internal { + source: irpc::Error, + }, +} + +impl From for io::Error { + fn from(value: ProgressError) -> Self { + match value { + ProgressError::Limit => io::ErrorKind::QuotaExceeded.into(), + ProgressError::Permission => io::ErrorKind::PermissionDenied.into(), + ProgressError::Internal { source } => source.into(), + } + } +} + +pub trait HasErrorCode { + fn code(&self) -> quinn::VarInt; +} + +impl HasErrorCode for ProgressError { + fn code(&self) -> quinn::VarInt { + match self { + ProgressError::Limit => ERR_LIMIT, + ProgressError::Permission => ERR_PERMISSION, + ProgressError::Internal { .. } => ERR_INTERNAL, + } + } +} + +impl ProgressError { + pub fn reason(&self) -> &'static [u8] { + match self { + ProgressError::Limit => b"limit", + ProgressError::Permission => b"permission", + ProgressError::Internal { .. } => b"internal", + } + } +} + +impl From for ProgressError { + fn from(value: AbortReason) -> Self { + match value { + AbortReason::RateLimited => ProgressError::Limit, + AbortReason::Permission => ProgressError::Permission, + } + } +} + +impl From for ProgressError { + fn from(value: irpc::channel::RecvError) -> Self { + ProgressError::Internal { + source: value.into(), + } + } +} + +impl From for ProgressError { + fn from(value: irpc::channel::SendError) -> Self { + ProgressError::Internal { + source: value.into(), + } + } +} + +pub type EventResult = Result<(), AbortReason>; +pub type ClientResult = Result<(), ProgressError>; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct EventMask { + /// Connection event mask + pub connected: ConnectMode, + /// Get request event mask + pub get: RequestMode, + /// Get many request event mask + pub get_many: RequestMode, + /// Push request event mask + pub push: RequestMode, + /// Observe request event mask + pub observe: ObserveMode, + /// throttling is somewhat costly, so you can disable it completely + pub throttle: ThrottleMode, +} + +impl Default for EventMask { + fn default() -> Self { + Self::DEFAULT + } +} + +impl EventMask { + /// All event notifications are fully disabled. Push requests are disabled by default. + pub const DEFAULT: Self = Self { + connected: ConnectMode::None, + get: RequestMode::None, + get_many: RequestMode::None, + push: RequestMode::Disabled, + throttle: ThrottleMode::None, + observe: ObserveMode::None, + }; + + /// All event notifications for read-only requests are fully enabled. + /// + /// If you want to enable push requests, which can write to the local store, you + /// need to do it manually. Providing constants that have push enabled would + /// risk misuse. + pub const ALL_READONLY: Self = Self { + connected: ConnectMode::Request, + get: RequestMode::RequestLog, + get_many: RequestMode::RequestLog, + push: RequestMode::Disabled, + throttle: ThrottleMode::Throttle, + observe: ObserveMode::Request, + }; +} + +/// Newtype wrapper that wraps an event so that it is a distinct type for the notify variant. +#[derive(Debug, Serialize, Deserialize)] +pub struct Notify(T); + +impl Deref for Notify { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[derive(Debug, Default, Clone)] +pub struct EventSender { + mask: EventMask, + inner: Option>, +} + +#[derive(Debug, Default)] +enum RequestUpdates { + /// Request tracking was not configured, all ops are no-ops + #[default] + None, + /// Active request tracking, all ops actually send + Active(mpsc::Sender), + /// Disabled request tracking, we just hold on to the sender so it drops + /// once the request is completed or aborted. + Disabled(#[allow(dead_code)] mpsc::Sender), +} + +#[derive(Debug)] +pub struct RequestTracker { + updates: RequestUpdates, + throttle: Option<(irpc::Client, u64, u64)>, +} + +impl RequestTracker { + fn new( + updates: RequestUpdates, + throttle: Option<(irpc::Client, u64, u64)>, + ) -> Self { + Self { updates, throttle } + } + + /// A request tracker that doesn't track anything. + pub const NONE: Self = Self { + updates: RequestUpdates::None, + throttle: None, + }; + + /// Transfer for index `index` started, size `size` + pub async fn transfer_started(&self, index: u64, hash: &Hash, size: u64) -> irpc::Result<()> { + if let RequestUpdates::Active(tx) = &self.updates { + tx.send( + TransferStarted { + index, + hash: *hash, + size, + } + .into(), + ) + .await?; + } + Ok(()) + } + + /// Transfer progress for the previously reported blob, end_offset is the new end offset in bytes. + pub async fn transfer_progress(&mut self, len: u64, end_offset: u64) -> ClientResult { + if let RequestUpdates::Active(tx) = &mut self.updates { + tx.try_send(TransferProgress { end_offset }.into()).await?; + } + if let Some((throttle, connection_id, request_id)) = &self.throttle { + throttle + .rpc(Throttle { + connection_id: *connection_id, + request_id: *request_id, + size: len, + }) + .await??; + } + Ok(()) + } + + /// Transfer completed for the previously reported blob. + pub async fn transfer_completed(&self, f: impl Fn() -> Box) -> irpc::Result<()> { + if let RequestUpdates::Active(tx) = &self.updates { + tx.send(TransferCompleted { stats: f() }.into()).await?; + } + Ok(()) + } + + /// Transfer aborted for the previously reported blob. + pub async fn transfer_aborted(&self, f: impl Fn() -> Box) -> irpc::Result<()> { + if let RequestUpdates::Active(tx) = &self.updates { + tx.send(TransferAborted { stats: f() }.into()).await?; + } + Ok(()) + } +} + +/// Client for progress notifications. +/// +/// For most event types, the client can be configured to either send notifications or requests that +/// can have a response. +impl EventSender { + /// A client that does not send anything. + pub const DEFAULT: Self = Self { + mask: EventMask::DEFAULT, + inner: None, + }; + + pub fn new(client: tokio::sync::mpsc::Sender, mask: EventMask) -> Self { + Self { + mask, + inner: Some(irpc::Client::from(client)), + } + } + + pub fn channel( + capacity: usize, + mask: EventMask, + ) -> (Self, tokio::sync::mpsc::Receiver) { + let (tx, rx) = tokio::sync::mpsc::channel(capacity); + (Self::new(tx, mask), rx) + } + + /// Log request events at trace level. + pub fn tracing(&self, mask: EventMask) -> Self { + use tracing::trace; + let (tx, mut rx) = tokio::sync::mpsc::channel(32); + n0_future::task::spawn(async move { + fn log_request_events( + mut rx: irpc::channel::mpsc::Receiver, + connection_id: u64, + request_id: u64, + ) { + n0_future::task::spawn(async move { + while let Ok(Some(update)) = rx.recv().await { + trace!(%connection_id, %request_id, "{update:?}"); + } + }); + } + while let Some(msg) = rx.recv().await { + match msg { + ProviderMessage::ClientConnected(msg) => { + trace!("{:?}", msg.inner); + msg.tx.send(Ok(())).await.ok(); + } + ProviderMessage::ClientConnectedNotify(msg) => { + trace!("{:?}", msg.inner); + } + ProviderMessage::ConnectionClosed(msg) => { + trace!("{:?}", msg.inner); + } + ProviderMessage::GetRequestReceived(msg) => { + trace!("{:?}", msg.inner); + msg.tx.send(Ok(())).await.ok(); + log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id); + } + ProviderMessage::GetRequestReceivedNotify(msg) => { + trace!("{:?}", msg.inner); + log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id); + } + ProviderMessage::GetManyRequestReceived(msg) => { + trace!("{:?}", msg.inner); + msg.tx.send(Ok(())).await.ok(); + log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id); + } + ProviderMessage::GetManyRequestReceivedNotify(msg) => { + trace!("{:?}", msg.inner); + log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id); + } + ProviderMessage::PushRequestReceived(msg) => { + trace!("{:?}", msg.inner); + msg.tx.send(Ok(())).await.ok(); + log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id); + } + ProviderMessage::PushRequestReceivedNotify(msg) => { + trace!("{:?}", msg.inner); + log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id); + } + ProviderMessage::ObserveRequestReceived(msg) => { + trace!("{:?}", msg.inner); + msg.tx.send(Ok(())).await.ok(); + log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id); + } + ProviderMessage::ObserveRequestReceivedNotify(msg) => { + trace!("{:?}", msg.inner); + log_request_events(msg.rx, msg.inner.connection_id, msg.inner.request_id); + } + ProviderMessage::Throttle(msg) => { + trace!("{:?}", msg.inner); + msg.tx.send(Ok(())).await.ok(); + } + } + } + }); + Self { + mask, + inner: Some(irpc::Client::from(tx)), + } + } + + /// A new client has been connected. + pub async fn client_connected(&self, f: impl Fn() -> ClientConnected) -> ClientResult { + if let Some(client) = &self.inner { + match self.mask.connected { + ConnectMode::None => {} + ConnectMode::Notify => client.notify(Notify(f())).await?, + ConnectMode::Request => client.rpc(f()).await??, + } + }; + Ok(()) + } + + /// A new client has been connected. + pub async fn connection_closed(&self, f: impl Fn() -> ConnectionClosed) -> ClientResult { + if let Some(client) = &self.inner { + client.notify(f()).await?; + }; + Ok(()) + } + + /// Abstract request, to DRY the 3 to 4 request types. + /// + /// DRYing stuff with lots of bounds is no fun at all... + pub(crate) async fn request( + &self, + f: impl FnOnce() -> Req, + connection_id: u64, + request_id: u64, + ) -> Result + where + ProviderProto: From>, + ProviderMessage: From, ProviderProto>>, + RequestReceived: Channels< + ProviderProto, + Tx = oneshot::Sender, + Rx = mpsc::Receiver, + >, + ProviderProto: From>>, + ProviderMessage: From>, ProviderProto>>, + Notify>: + Channels>, + { + let client = self.inner.as_ref(); + Ok(self.create_tracker(( + match self.mask.get { + RequestMode::None => RequestUpdates::None, + RequestMode::Notify if client.is_some() => { + let msg = RequestReceived { + request: f(), + connection_id, + request_id, + }; + RequestUpdates::Disabled( + client.unwrap().notify_streaming(Notify(msg), 32).await?, + ) + } + RequestMode::Request if client.is_some() => { + let msg = RequestReceived { + request: f(), + connection_id, + request_id, + }; + let (tx, rx) = client.unwrap().client_streaming(msg, 32).await?; + // bail out if the request is not allowed + rx.await??; + RequestUpdates::Disabled(tx) + } + RequestMode::NotifyLog if client.is_some() => { + let msg = RequestReceived { + request: f(), + connection_id, + request_id, + }; + RequestUpdates::Active(client.unwrap().notify_streaming(Notify(msg), 32).await?) + } + RequestMode::RequestLog if client.is_some() => { + let msg = RequestReceived { + request: f(), + connection_id, + request_id, + }; + let (tx, rx) = client.unwrap().client_streaming(msg, 32).await?; + // bail out if the request is not allowed + rx.await??; + RequestUpdates::Active(tx) + } + RequestMode::Disabled => { + return Err(ProgressError::Permission); + } + _ => RequestUpdates::None, + }, + connection_id, + request_id, + ))) + } + + fn create_tracker( + &self, + (updates, connection_id, request_id): (RequestUpdates, u64, u64), + ) -> RequestTracker { + let throttle = match self.mask.throttle { + ThrottleMode::None => None, + ThrottleMode::Throttle => self + .inner + .clone() + .map(|client| (client, connection_id, request_id)), + }; + RequestTracker::new(updates, throttle) + } +} + +#[rpc_requests(message = ProviderMessage)] +#[derive(Debug, Serialize, Deserialize)] +pub enum ProviderProto { + /// A new client connected to the provider. + #[rpc(tx = oneshot::Sender)] + ClientConnected(ClientConnected), + + /// A new client connected to the provider. Notify variant. + #[rpc(tx = NoSender)] + ClientConnectedNotify(Notify), + + /// A client disconnected from the provider. + #[rpc(tx = NoSender)] + ConnectionClosed(ConnectionClosed), + + #[rpc(rx = mpsc::Receiver, tx = oneshot::Sender)] + /// A new get request was received from the provider. + GetRequestReceived(RequestReceived), + + #[rpc(rx = mpsc::Receiver, tx = NoSender)] + /// A new get request was received from the provider. + GetRequestReceivedNotify(Notify>), + + /// A new get request was received from the provider. + #[rpc(rx = mpsc::Receiver, tx = oneshot::Sender)] + GetManyRequestReceived(RequestReceived), + + /// A new get request was received from the provider. + #[rpc(rx = mpsc::Receiver, tx = NoSender)] + GetManyRequestReceivedNotify(Notify>), + + /// A new get request was received from the provider. + #[rpc(rx = mpsc::Receiver, tx = oneshot::Sender)] + PushRequestReceived(RequestReceived), + + /// A new get request was received from the provider. + #[rpc(rx = mpsc::Receiver, tx = NoSender)] + PushRequestReceivedNotify(Notify>), + + /// A new get request was received from the provider. + #[rpc(rx = mpsc::Receiver, tx = oneshot::Sender)] + ObserveRequestReceived(RequestReceived), + + /// A new get request was received from the provider. + #[rpc(rx = mpsc::Receiver, tx = NoSender)] + ObserveRequestReceivedNotify(Notify>), + + #[rpc(tx = oneshot::Sender)] + Throttle(Throttle), +} + +mod proto { + use iroh::NodeId; + use serde::{Deserialize, Serialize}; + + use crate::{provider::TransferStats, Hash}; + + #[derive(Debug, Serialize, Deserialize)] + pub struct ClientConnected { + pub connection_id: u64, + pub node_id: NodeId, + } + + #[derive(Debug, Serialize, Deserialize)] + pub struct ConnectionClosed { + pub connection_id: u64, + } + + /// A new get request was received from the provider. + #[derive(Debug, Serialize, Deserialize)] + pub struct RequestReceived { + /// The connection id. Multiple requests can be sent over the same connection. + pub connection_id: u64, + /// The request id. There is a new id for each request. + pub request_id: u64, + /// The request + pub request: R, + } + + /// Request to throttle sending for a specific request. + #[derive(Debug, Serialize, Deserialize)] + pub struct Throttle { + /// The connection id. Multiple requests can be sent over the same connection. + pub connection_id: u64, + /// The request id. There is a new id for each request. + pub request_id: u64, + /// Size of the chunk to be throttled. This will usually be 16 KiB. + pub size: u64, + } + + #[derive(Debug, Serialize, Deserialize)] + pub struct TransferProgress { + /// The end offset of the chunk that was sent. + pub end_offset: u64, + } + + #[derive(Debug, Serialize, Deserialize)] + pub struct TransferStarted { + pub index: u64, + pub hash: Hash, + pub size: u64, + } + + #[derive(Debug, Serialize, Deserialize)] + pub struct TransferCompleted { + pub stats: Box, + } + + #[derive(Debug, Serialize, Deserialize)] + pub struct TransferAborted { + pub stats: Box, + } + + /// Stream of updates for a single request + #[derive(Debug, Serialize, Deserialize, derive_more::From)] + pub enum RequestUpdate { + /// Start of transfer for a blob, mandatory event + Started(TransferStarted), + /// Progress for a blob - optional event + Progress(TransferProgress), + /// Successful end of transfer + Completed(TransferCompleted), + /// Aborted end of transfer + Aborted(TransferAborted), + } +} +pub use proto::*; + +mod irpc_ext { + use std::future::Future; + + use irpc::{ + channel::{mpsc, none::NoSender}, + Channels, RpcMessage, Service, WithChannels, + }; + + pub trait IrpcClientExt { + fn notify_streaming( + &self, + msg: Req, + local_update_cap: usize, + ) -> impl Future>> + where + S: From, + S::Message: From>, + Req: Channels>, + Update: RpcMessage; + } + + impl IrpcClientExt for irpc::Client { + fn notify_streaming( + &self, + msg: Req, + local_update_cap: usize, + ) -> impl Future>> + where + S: From, + S::Message: From>, + Req: Channels>, + Update: RpcMessage, + { + let client = self.clone(); + async move { + let request = client.request().await?; + match request { + irpc::Request::Local(local) => { + let (req_tx, req_rx) = mpsc::channel(local_update_cap); + local + .send((msg, NoSender, req_rx)) + .await + .map_err(irpc::Error::from)?; + Ok(req_tx) + } + irpc::Request::Remote(remote) => { + let (s, _) = remote.write(msg).await?; + Ok(s.into()) + } + } + } + } + } +} diff --git a/src/store/fs/util/entity_manager.rs b/src/store/fs/util/entity_manager.rs index 91a737d76..b0b2898ea 100644 --- a/src/store/fs/util/entity_manager.rs +++ b/src/store/fs/util/entity_manager.rs @@ -1186,10 +1186,6 @@ mod tests { .spawn(id, move |arg| async move { match arg { SpawnArg::Active(state) => { - println!( - "Adding value {} to entity actor with id {:?}", - value, state.id - ); state .with_value(|v| *v = v.wrapping_add(value)) .await diff --git a/src/tests.rs b/src/tests.rs index e7dc823e6..9e5ea89f2 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -16,7 +16,7 @@ use crate::{ hashseq::HashSeq, net_protocol::BlobsProtocol, protocol::{ChunkRangesSeq, GetManyRequest, ObserveRequest, PushRequest}, - provider::Event, + provider::events::{AbortReason, EventMask, EventSender, ProviderMessage, RequestUpdate}, store::{ fs::{ tests::{create_n0_bao, test_data, INTERESTING_SIZES}, @@ -340,27 +340,31 @@ async fn two_nodes_get_many_mem() -> TestResult<()> { fn event_handler( allowed_nodes: impl IntoIterator, -) -> ( - mpsc::Sender, - watch::Receiver, - AbortOnDropHandle<()>, -) { +) -> (EventSender, watch::Receiver, AbortOnDropHandle<()>) { let (count_tx, count_rx) = tokio::sync::watch::channel(0usize); - let (events_tx, mut events_rx) = mpsc::channel::(16); + let (events_tx, mut events_rx) = EventSender::channel(16, EventMask::ALL_READONLY); let allowed_nodes = allowed_nodes.into_iter().collect::>(); let task = AbortOnDropHandle::new(tokio::task::spawn(async move { while let Some(event) = events_rx.recv().await { match event { - Event::ClientConnected { - node_id, permitted, .. - } => { - permitted.send(allowed_nodes.contains(&node_id)).await.ok(); + ProviderMessage::ClientConnected(msg) => { + let res = if allowed_nodes.contains(&msg.inner.node_id) { + Ok(()) + } else { + Err(AbortReason::Permission) + }; + msg.tx.send(res).await.ok(); } - Event::PushRequestReceived { permitted, .. } => { - permitted.send(true).await.ok(); - } - Event::TransferCompleted { .. } => { - count_tx.send_modify(|count| *count += 1); + ProviderMessage::PushRequestReceived(mut msg) => { + msg.tx.send(Ok(())).await.ok(); + let count_tx = count_tx.clone(); + tokio::task::spawn(async move { + while let Ok(Some(update)) = msg.rx.recv().await { + if let RequestUpdate::Completed(_) = update { + count_tx.send_modify(|x| *x += 1); + } + } + }); } _ => {} } @@ -409,7 +413,7 @@ async fn two_nodes_push_blobs_fs() -> TestResult<()> { let (r1, store1, _) = node_test_setup_fs(testdir.path().join("a")).await?; let (events_tx, count_rx, _task) = event_handler([r1.endpoint().node_id()]); let (r2, store2, _) = - node_test_setup_with_events_fs(testdir.path().join("b"), Some(events_tx)).await?; + node_test_setup_with_events_fs(testdir.path().join("b"), events_tx).await?; two_nodes_push_blobs(r1, &store1, r2, &store2, count_rx).await } @@ -418,7 +422,7 @@ async fn two_nodes_push_blobs_mem() -> TestResult<()> { tracing_subscriber::fmt::try_init().ok(); let (r1, store1) = node_test_setup_mem().await?; let (events_tx, count_rx, _task) = event_handler([r1.endpoint().node_id()]); - let (r2, store2) = node_test_setup_with_events_mem(Some(events_tx)).await?; + let (r2, store2) = node_test_setup_with_events_mem(events_tx).await?; two_nodes_push_blobs(r1, &store1, r2, &store2, count_rx).await } @@ -481,30 +485,30 @@ async fn check_presence(store: &Store, sizes: &[usize]) -> TestResult<()> { } pub async fn node_test_setup_fs(db_path: PathBuf) -> TestResult<(Router, FsStore, PathBuf)> { - node_test_setup_with_events_fs(db_path, None).await + node_test_setup_with_events_fs(db_path, EventSender::DEFAULT).await } pub async fn node_test_setup_with_events_fs( db_path: PathBuf, - events: Option>, + events: EventSender, ) -> TestResult<(Router, FsStore, PathBuf)> { let store = crate::store::fs::FsStore::load(&db_path).await?; let ep = Endpoint::builder().bind().await?; - let blobs = BlobsProtocol::new(&store, ep.clone(), events); + let blobs = BlobsProtocol::new(&store, ep.clone(), Some(events)); let router = Router::builder(ep).accept(crate::ALPN, blobs).spawn(); Ok((router, store, db_path)) } pub async fn node_test_setup_mem() -> TestResult<(Router, MemStore)> { - node_test_setup_with_events_mem(None).await + node_test_setup_with_events_mem(EventSender::DEFAULT).await } pub async fn node_test_setup_with_events_mem( - events: Option>, + events: EventSender, ) -> TestResult<(Router, MemStore)> { let store = MemStore::new(); let ep = Endpoint::builder().bind().await?; - let blobs = BlobsProtocol::new(&store, ep.clone(), events); + let blobs = BlobsProtocol::new(&store, ep.clone(), Some(events)); let router = Router::builder(ep).accept(crate::ALPN, blobs).spawn(); Ok((router, store)) } @@ -552,6 +556,7 @@ async fn two_nodes_hash_seq( } #[tokio::test] + async fn two_nodes_hash_seq_fs() -> TestResult<()> { tracing_subscriber::fmt::try_init().ok(); let (_testdir, (r1, store1, _), (r2, store2, _)) = two_node_test_setup_fs().await?; @@ -574,9 +579,7 @@ async fn two_nodes_hash_seq_progress() -> TestResult<()> { let root = add_test_hash_seq(&store1, sizes).await?; let conn = r2.endpoint().connect(addr1, crate::ALPN).await?; let mut stream = store2.remote().fetch(conn, root).stream(); - while let Some(item) = stream.next().await { - println!("{item:?}"); - } + while stream.next().await.is_some() {} check_presence(&store2, &sizes).await?; Ok(()) } @@ -644,9 +647,7 @@ async fn node_serve_blobs() -> TestResult<()> { let expected = test_data(size); let hash = Hash::new(&expected); let mut stream = get::request::get_blob(conn.clone(), hash); - while let Some(item) = stream.next().await { - println!("{item:?}"); - } + while stream.next().await.is_some() {} let actual = get::request::get_blob(conn.clone(), hash).await?; assert_eq!(actual.len(), expected.len(), "size: {size}"); } diff --git a/src/util.rs b/src/util.rs index 3fdaacbca..40abf0343 100644 --- a/src/util.rs +++ b/src/util.rs @@ -363,7 +363,7 @@ pub(crate) mod outboard_with_progress { } pub(crate) mod sink { - use std::{future::Future, io}; + use std::future::Future; use irpc::RpcMessage; @@ -433,10 +433,13 @@ pub(crate) mod sink { pub struct TokioMpscSenderSink(pub tokio::sync::mpsc::Sender); impl Sink for TokioMpscSenderSink { - type Error = tokio::sync::mpsc::error::SendError; + type Error = irpc::channel::SendError; async fn send(&mut self, value: T) -> std::result::Result<(), Self::Error> { - self.0.send(value).await + self.0 + .send(value) + .await + .map_err(|_| irpc::channel::SendError::ReceiverClosed) } } @@ -483,10 +486,10 @@ pub(crate) mod sink { pub struct Drain; impl Sink for Drain { - type Error = io::Error; + type Error = irpc::channel::SendError; async fn send(&mut self, _offset: T) -> std::result::Result<(), Self::Error> { - io::Result::Ok(()) + Ok(()) } } }