diff --git a/Cargo.lock b/Cargo.lock index 23e65b2d1..2d3988033 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -156,6 +156,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "async-compression" +version = "0.4.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "977eb15ea9efd848bb8a4a1a2500347ed7f0bf794edf0dc3ddcf439f43d36b23" +dependencies = [ + "compression-codecs", + "compression-core", + "futures-core", + "pin-project-lite", + "tokio", +] + [[package]] name = "async-trait" version = "0.1.88" @@ -508,6 +521,28 @@ dependencies = [ "memchr", ] +[[package]] +name = "compression-codecs" +version = "0.4.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "485abf41ac0c8047c07c87c72c8fb3eb5197f6e9d7ded615dfd1a00ae00a0f64" +dependencies = [ + "compression-core", + "lz4", +] + +[[package]] +name = "compression-core" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e47641d3deaf41fb1538ac1f54735925e275eaf3bf4d55c81b137fba797e5cbb" + +[[package]] +name = "concat_const" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60c92cd5ec953d0542f48d2a90a25aa2828ab1c03217c1ca077000f3af15997d" + [[package]] name = "const-oid" version = "0.9.6" @@ -1247,20 +1282,25 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8a6fe56c0038198998a6f217ca4e7ef3a5e51f46163bd6dd60b5c71ca6c6502" dependencies = [ "async-trait", + "bytes", "cfg-if", "data-encoding", "enum-as-inner", "futures-channel", "futures-io", "futures-util", + "h2", + "http", "idna", "ipnet", "once_cell", "rand 0.9.2", "ring", + "rustls", "thiserror 2.0.12", "tinyvec", "tokio", + "tokio-rustls", "tracing", "url", ] @@ -1280,9 +1320,11 @@ dependencies = [ "parking_lot", "rand 0.9.2", "resolv-conf", + "rustls", "smallvec", "thiserror 2.0.12", "tokio", + "tokio-rustls", "tracing", ] @@ -1659,8 +1701,7 @@ dependencies = [ [[package]] name = "iroh" version = "0.92.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "135ad6b793a5851b9e5435ad36fea63df485f8fd4520a58117e7dc3326a69c15" +source = "git+https://github.com/n0-computer/iroh?branch=main#60d5310dfe42179f6b3a20e38da4e7144008e541" dependencies = [ "aead", "backon", @@ -1721,8 +1762,7 @@ dependencies = [ [[package]] name = "iroh-base" version = "0.92.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04ae51a14c9255a735b1db2d8cf29b875b971e96a5b23e4d0d1ee7d85bf32132" +source = "git+https://github.com/n0-computer/iroh?branch=main#60d5310dfe42179f6b3a20e38da4e7144008e541" dependencies = [ "curve25519-dalek", "data-encoding", @@ -1743,11 +1783,13 @@ version = "0.94.0" dependencies = [ "anyhow", "arrayvec", + "async-compression", "atomic_refcell", "bao-tree", "bytes", "chrono", "clap", + "concat_const", "data-encoding", "derive_more 2.0.1", "futures-lite", @@ -1884,8 +1926,7 @@ dependencies = [ [[package]] name = "iroh-relay" version = "0.92.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "315cb02e660de0de339303296df9a29b27550180bb3979d0753a267649b34a7f" +source = "git+https://github.com/n0-computer/iroh?branch=main#60d5310dfe42179f6b3a20e38da4e7144008e541" dependencies = [ "blake3", "bytes", @@ -2095,6 +2136,25 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" +[[package]] +name = "lz4" +version = "1.28.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "a20b523e860d03443e98350ceaac5e71c6ba89aea7d960769ec3ce37f4de5af4" +dependencies = [ + "lz4-sys", +] + +[[package]] +name = "lz4-sys" +version = "1.11.1+lz4-1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bd8c0d6c6ed0cd30b3652886bb8711dc4bb01d637a68105a3d5158039b418e6" +dependencies = [ + "cc", + "libc", +] + [[package]] name = "matchers" version = "0.2.0" diff --git a/Cargo.toml b/Cargo.toml index 70eb73a21..976da2cc4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,9 +60,15 @@ tracing-test = "0.2.5" walkdir = "2.5.0" atomic_refcell = "0.1.13" iroh = { version = "0.92", features = ["discovery-local-network"]} +async-compression = { version = "0.4.30", features = ["lz4", "tokio"] } +concat_const = "0.2.0" [features] hide-proto-docs = [] metrics = [] default = ["hide-proto-docs", "fs-store"] fs-store = ["dep:redb", "dep:reflink-copy"] + +[patch.crates-io] +iroh = { git = "https://github.com/n0-computer/iroh", branch = "main" } +iroh-base = { git = "https://github.com/n0-computer/iroh", branch = "main" } diff --git a/examples/compression.rs b/examples/compression.rs new file mode 100644 index 000000000..343209cd8 --- /dev/null +++ b/examples/compression.rs @@ -0,0 +1,229 @@ +/// Example how to use compression with iroh-blobs +/// +/// We create a derived protocol that compresses both requests and responses using lz4 +/// or any other compression algorithm supported by async-compression. +mod common; +use std::{fmt::Debug, path::PathBuf}; + +use anyhow::Result; +use clap::Parser; +use common::setup_logging; +use iroh::protocol::ProtocolHandler; +use iroh_blobs::{ + api::Store, + get::StreamPair, + provider::{ + self, + events::{ClientConnected, EventSender, HasErrorCode}, + handle_stream, + }, + store::mem::MemStore, + ticket::BlobTicket, +}; +use tracing::debug; + +use crate::common::get_or_generate_secret_key; + +#[derive(Debug, Parser)] +#[command(version, about)] +pub enum Args { + /// Limit requests by node id + Provide { + /// Path for files to add. + path: PathBuf, + }, + /// Get a blob. Just for completeness sake. + Get { + /// Ticket for the blob to download + ticket: BlobTicket, + /// Path to save the blob to + #[clap(long)] + target: Option, + }, +} + +trait Compression: Clone + Send + Sync + Debug + 'static { + const ALPN: &'static [u8]; + fn recv_stream( + &self, + stream: iroh::endpoint::RecvStream, + ) -> impl iroh_blobs::util::RecvStream + Sync + 'static; + fn send_stream( + &self, + stream: iroh::endpoint::SendStream, + ) -> impl iroh_blobs::util::SendStream + Sync + 'static; +} + +mod lz4 { + use std::io; + + use async_compression::tokio::{bufread::Lz4Decoder, write::Lz4Encoder}; + use iroh::endpoint::VarInt; + use iroh_blobs::util::{ + AsyncReadRecvStream, AsyncReadRecvStreamExtra, AsyncWriteSendStream, + AsyncWriteSendStreamExtra, + }; + use tokio::io::{AsyncRead, AsyncWrite, BufReader}; + + struct SendStream(Lz4Encoder); + + impl SendStream { + pub fn new(inner: iroh::endpoint::SendStream) -> AsyncWriteSendStream { + AsyncWriteSendStream::new(Self(Lz4Encoder::new(inner))) + } + } + + impl AsyncWriteSendStreamExtra for SendStream { + fn inner(&mut self) -> &mut (impl AsyncWrite + Unpin + Send) { + &mut self.0 + } + + fn reset(&mut self, code: VarInt) -> io::Result<()> { + Ok(self.0.get_mut().reset(code)?) + } + + async fn stopped(&mut self) -> io::Result> { + Ok(self.0.get_mut().stopped().await?) 
+ } + + fn id(&self) -> u64 { + self.0.get_ref().id().index() + } + } + + struct RecvStream(Lz4Decoder>); + + impl RecvStream { + pub fn new(inner: iroh::endpoint::RecvStream) -> AsyncReadRecvStream { + AsyncReadRecvStream::new(Self(Lz4Decoder::new(BufReader::new(inner)))) + } + } + + impl AsyncReadRecvStreamExtra for RecvStream { + fn inner(&mut self) -> &mut (impl AsyncRead + Unpin + Send) { + &mut self.0 + } + + fn stop(&mut self, code: VarInt) -> io::Result<()> { + Ok(self.0.get_mut().get_mut().stop(code)?) + } + + fn id(&self) -> u64 { + self.0.get_ref().get_ref().id().index() + } + } + + #[derive(Debug, Clone)] + pub struct Compression; + + impl super::Compression for Compression { + const ALPN: &[u8] = concat_const::concat_bytes!(b"lz4/", iroh_blobs::ALPN); + fn recv_stream( + &self, + stream: iroh::endpoint::RecvStream, + ) -> impl iroh_blobs::util::RecvStream + Sync + 'static { + RecvStream::new(stream) + } + fn send_stream( + &self, + stream: iroh::endpoint::SendStream, + ) -> impl iroh_blobs::util::SendStream + Sync + 'static { + SendStream::new(stream) + } + } +} + +#[derive(Debug, Clone)] +struct CompressedBlobsProtocol { + store: Store, + events: EventSender, + compression: C, +} + +impl CompressedBlobsProtocol { + fn new(store: &Store, events: EventSender, compression: C) -> Self { + Self { + store: store.clone(), + events, + compression, + } + } +} + +impl ProtocolHandler for CompressedBlobsProtocol { + async fn accept( + &self, + connection: iroh::endpoint::Connection, + ) -> std::result::Result<(), iroh::protocol::AcceptError> { + let connection_id = connection.stable_id() as u64; + if let Err(cause) = self + .events + .client_connected(|| ClientConnected { + connection_id, + node_id: connection.remote_node_id().ok(), + }) + .await + { + connection.close(cause.code(), cause.reason()); + debug!("closing connection: {cause}"); + return Ok(()); + } + while let Ok((send, recv)) = connection.accept_bi().await { + let send = self.compression.send_stream(send); + let recv = self.compression.recv_stream(recv); + let store = self.store.clone(); + let pair = provider::StreamPair::new(connection_id, recv, send, self.events.clone()); + tokio::spawn(handle_stream(pair, store)); + } + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + setup_logging(); + let args = Args::parse(); + let secret = get_or_generate_secret_key()?; + let endpoint = iroh::Endpoint::builder() + .secret_key(secret) + .discovery_n0() + .bind() + .await?; + let compression = lz4::Compression; + match args { + Args::Provide { path } => { + let store = MemStore::new(); + let tag = store.add_path(path).await?; + let blobs = CompressedBlobsProtocol::new(&store, EventSender::DEFAULT, compression); + let router = iroh::protocol::Router::builder(endpoint.clone()) + .accept(lz4::Compression::ALPN, blobs) + .spawn(); + let ticket = BlobTicket::new(endpoint.node_id().into(), tag.hash, tag.format); + println!("Serving blob with hash {}", tag.hash); + println!("Ticket: {ticket}"); + println!("Node is running. 
Press Ctrl-C to exit."); + tokio::signal::ctrl_c().await?; + println!("Shutting down."); + router.shutdown().await?; + } + Args::Get { ticket, target } => { + let store = MemStore::new(); + let conn = endpoint + .connect(ticket.node_addr().clone(), lz4::Compression::ALPN) + .await?; + let connection_id = conn.stable_id() as u64; + let (send, recv) = conn.open_bi().await?; + let send = compression.send_stream(send); + let recv = compression.recv_stream(recv); + let sp = StreamPair::new(connection_id, recv, send); + let _stats = store.remote().fetch(sp, ticket.hash_and_format()).await?; + if let Some(target) = target { + let size = store.export(ticket.hash(), &target).await?; + println!("Wrote {} bytes to {}", size, target.display()); + } else { + println!("Hash: {}", ticket.hash()); + } + } + } + Ok(()) +} diff --git a/src/api/blobs.rs b/src/api/blobs.rs index 897e0371c..6e8bbc3c3 100644 --- a/src/api/blobs.rs +++ b/src/api/blobs.rs @@ -23,14 +23,12 @@ use bao_tree::{ }; use bytes::Bytes; use genawaiter::sync::Gen; -use iroh_io::{AsyncStreamReader, TokioStreamReader}; +use iroh_io::AsyncStreamWriter; use irpc::channel::{mpsc, oneshot}; use n0_future::{future, stream, Stream, StreamExt}; -use quinn::SendStream; use range_collections::{range_set::RangeSetRange, RangeSet2}; use ref_cast::RefCast; use serde::{Deserialize, Serialize}; -use tokio::io::AsyncWriteExt; use tracing::trace; mod reader; pub use reader::BlobReader; @@ -59,7 +57,7 @@ use crate::{ api::proto::{BatchRequest, ImportByteStreamUpdate}, provider::events::ClientResult, store::IROH_BLOCK_SIZE, - util::temp_tag::TempTag, + util::{temp_tag::TempTag, RecvStreamAsyncStreamReader}, BlobFormat, Hash, HashAndFormat, }; @@ -433,13 +431,18 @@ impl Blobs { } #[cfg_attr(feature = "hide-proto-docs", doc(hidden))] - async fn import_bao_reader( + pub async fn import_bao_reader( &self, hash: Hash, ranges: ChunkRanges, mut reader: R, ) -> RequestResult { - let size = u64::from_le_bytes(reader.read::<8>().await.map_err(super::Error::other)?); + let mut size = [0; 8]; + reader + .recv_exact(&mut size) + .await + .map_err(super::Error::other)?; + let size = u64::from_le_bytes(size); let Some(size) = NonZeroU64::new(size) else { return if hash == Hash::EMPTY { Ok(reader) @@ -448,7 +451,12 @@ impl Blobs { }; }; let tree = BaoTree::new(size.get(), IROH_BLOCK_SIZE); - let mut decoder = ResponseDecoder::new(hash.into(), ranges, tree, reader); + let mut decoder = ResponseDecoder::new( + hash.into(), + ranges, + tree, + RecvStreamAsyncStreamReader::new(reader), + ); let options = ImportBaoOptions { hash, size }; let handle = self.import_bao_with_opts(options, 32).await?; let driver = async move { @@ -467,19 +475,7 @@ impl Blobs { let fut = async move { handle.rx.await.map_err(io::Error::other)? }; let (reader, res) = tokio::join!(driver, fut); res?; - Ok(reader?) 
- } - - #[cfg_attr(feature = "hide-proto-docs", doc(hidden))] - pub async fn import_bao_quinn( - &self, - hash: Hash, - ranges: ChunkRanges, - stream: &mut iroh::endpoint::RecvStream, - ) -> RequestResult<()> { - let reader = TokioStreamReader::new(stream); - self.import_bao_reader(hash, ranges, reader).await?; - Ok(()) + Ok(reader?.into_inner()) } #[cfg_attr(feature = "hide-proto-docs", doc(hidden))] @@ -1061,24 +1057,21 @@ impl ExportBaoProgress { Ok(data) } - pub async fn write_quinn(self, target: &mut quinn::SendStream) -> super::ExportBaoResult<()> { + pub async fn write(self, target: &mut W) -> super::ExportBaoResult<()> { let mut rx = self.inner.await?; while let Some(item) = rx.recv().await? { match item { EncodedItem::Size(size) => { - target.write_u64_le(size).await?; + target.write(&size.to_le_bytes()).await?; } EncodedItem::Parent(parent) => { let mut data = vec![0u8; 64]; data[..32].copy_from_slice(parent.pair.0.as_bytes()); data[32..].copy_from_slice(parent.pair.1.as_bytes()); - target.write_all(&data).await.map_err(io::Error::from)?; + target.write(&data).await?; } EncodedItem::Leaf(leaf) => { - target - .write_chunk(leaf.data) - .await - .map_err(io::Error::from)?; + target.write_bytes(leaf.data).await?; } EncodedItem::Done => break, EncodedItem::Error(cause) => return Err(cause.into()), @@ -1088,9 +1081,9 @@ impl ExportBaoProgress { } /// Write quinn variant that also feeds a progress writer. - pub(crate) async fn write_quinn_with_progress( + pub(crate) async fn write_with_progress( self, - writer: &mut SendStream, + writer: &mut W, progress: &mut impl WriteProgress, hash: &Hash, index: u64, @@ -1100,22 +1093,19 @@ impl ExportBaoProgress { match item { EncodedItem::Size(size) => { progress.send_transfer_started(index, hash, size).await; - writer.write_u64_le(size).await?; + writer.send(&size.to_le_bytes()).await?; progress.log_other_write(8); } EncodedItem::Parent(parent) => { - let mut data = vec![0u8; 64]; + let mut data = [0u8; 64]; data[..32].copy_from_slice(parent.pair.0.as_bytes()); data[32..].copy_from_slice(parent.pair.1.as_bytes()); - writer.write_all(&data).await.map_err(io::Error::from)?; + writer.send(&data).await?; progress.log_other_write(64); } EncodedItem::Leaf(leaf) => { let len = leaf.data.len(); - writer - .write_chunk(leaf.data) - .await - .map_err(io::Error::from)?; + writer.send_bytes(leaf.data).await?; progress .notify_payload_write(index, leaf.offset, len) .await?; diff --git a/src/api/downloader.rs b/src/api/downloader.rs index bf78bf793..82cef8393 100644 --- a/src/api/downloader.rs +++ b/src/api/downloader.rs @@ -456,7 +456,7 @@ async fn execute_get( }; match remote .execute_get_sink( - &conn, + conn.clone(), local.missing(), (&mut progress).with_map(move |x| DownloadProgressItem::Progress(x + local_bytes)), ) @@ -564,9 +564,7 @@ mod tests { .download(request, Shuffled::new(vec![node1_id, node2_id])) .stream() .await?; - while let Some(item) = progress.next().await { - println!("Got item: {item:?}"); - } + while progress.next().await.is_some() {} assert_eq!(store3.get_bytes(tt1.hash).await?.deref(), b"hello world"); assert_eq!(store3.get_bytes(tt2.hash).await?.deref(), b"hello world 2"); Ok(()) @@ -609,9 +607,7 @@ mod tests { )) .stream() .await?; - while let Some(item) = progress.next().await { - println!("Got item: {item:?}"); - } + while progress.next().await.is_some() {} } if false { let conn = r3.endpoint().connect(node1_addr, crate::ALPN).await?; @@ -673,9 +669,7 @@ mod tests { )) .stream() .await?; - while let Some(item) = 
progress.next().await { - println!("Got item: {item:?}"); - } + while progress.next().await.is_some() {} Ok(()) } } diff --git a/src/api/remote.rs b/src/api/remote.rs index dcfbc4fb4..cf73096a1 100644 --- a/src/api/remote.rs +++ b/src/api/remote.rs @@ -1,25 +1,54 @@ //! API for downloading blobs from a single remote node. //! //! The entry point is the [`Remote`] struct. +use std::{ + collections::BTreeMap, + future::{Future, IntoFuture}, + num::NonZeroU64, + sync::Arc, +}; + +use bao_tree::{ + io::{BaoContentItem, Leaf}, + ChunkNum, ChunkRanges, +}; use genawaiter::sync::{Co, Gen}; -use iroh::endpoint::SendStream; +use iroh::endpoint::Connection; use irpc::util::{AsyncReadVarintExt, WriteVarintExt}; use n0_future::{io, Stream, StreamExt}; use n0_snafu::SpanTrace; use nested_enum_utils::common_fields; use ref_cast::RefCast; -use snafu::{Backtrace, IntoError, Snafu}; +use snafu::{Backtrace, IntoError, ResultExt, Snafu}; +use tracing::{debug, trace}; use super::blobs::{Bitfield, ExportBaoOptions}; use crate::{ - api::{blobs::WriteProgress, ApiClient}, - get::{fsm::DecodeError, BadRequestSnafu, GetError, GetResult, LocalFailureSnafu, Stats}, + api::{ + self, + blobs::{Blobs, WriteProgress}, + ApiClient, Store, + }, + get::{ + fsm::{ + AtBlobHeader, AtConnected, AtEndBlob, BlobContentNext, ConnectedNext, DecodeError, + EndBlobNext, + }, + get_error::{BadRequestSnafu, LocalFailureSnafu}, + GetError, GetResult, Stats, StreamPair, + }, + hashseq::{HashSeq, HashSeqIter}, protocol::{ - GetManyRequest, ObserveItem, ObserveRequest, PushRequest, Request, RequestType, - MAX_MESSAGE_SIZE, + ChunkRangesSeq, GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, + Request, RequestType, MAX_MESSAGE_SIZE, }, provider::events::{ClientResult, ProgressError}, - util::sink::{Sink, TokioMpscSenderSink}, + store::IROH_BLOCK_SIZE, + util::{ + sink::{Sink, TokioMpscSenderSink}, + RecvStream, SendStream, + }, + Hash, HashAndFormat, }; /// API to compute request and to download from remote nodes. @@ -95,8 +124,7 @@ impl GetProgress { pub async fn complete(self) -> GetResult { just_result(self.stream()).await.unwrap_or_else(|| { - Err(LocalFailureSnafu - .into_error(anyhow::anyhow!("stream closed without result").into())) + Err(LocalFailureSnafu.into_error(anyhow::anyhow!("stream closed without result"))) }) } } @@ -473,7 +501,7 @@ impl Remote { pub fn fetch( &self, - conn: impl GetConnection + Send + 'static, + sp: impl GetStreamPair + 'static, content: impl Into, ) -> GetProgress { let content = content.into(); @@ -482,7 +510,7 @@ impl Remote { let sink = TokioMpscSenderSink(tx).with_map(GetProgressItem::Progress); let this = self.clone(); let fut = async move { - let res = this.fetch_sink(conn, content, sink).await.into(); + let res = this.fetch_sink(sp, content, sink).await.into(); tx2.send(res).await.ok(); }; GetProgress { @@ -500,7 +528,7 @@ impl Remote { /// This will return the stats of the download. 
pub(crate) async fn fetch_sink( &self, - mut conn: impl GetConnection, + sp: impl GetStreamPair, content: impl Into, progress: impl Sink, ) -> GetResult { @@ -508,16 +536,12 @@ impl Remote { let local = self .local(content) .await - .map_err(|e| LocalFailureSnafu.into_error(e.into()))?; + .map_err(|e: anyhow::Error| LocalFailureSnafu.into_error(e))?; if local.is_complete() { return Ok(Default::default()); } let request = local.missing(); - let conn = conn - .connection() - .await - .map_err(|e| LocalFailureSnafu.into_error(e.into()))?; - let stats = self.execute_get_sink(&conn, request, progress).await?; + let stats = self.execute_get_sink(sp, request, progress).await?; Ok(stats) } @@ -593,7 +617,7 @@ impl Remote { if !root_ranges.is_empty() { self.store() .export_bao(root, root_ranges.clone()) - .write_quinn_with_progress(&mut send, &mut context, &root, 0) + .write_with_progress(&mut send, &mut context, &root, 0) .await?; } if request.ranges.is_blob() { @@ -609,12 +633,7 @@ impl Remote { if !child_ranges.is_empty() { self.store() .export_bao(child_hash, child_ranges.clone()) - .write_quinn_with_progress( - &mut send, - &mut context, - &child_hash, - (child + 1) as u64, - ) + .write_with_progress(&mut send, &mut context, &child_hash, (child + 1) as u64) .await?; } } @@ -622,17 +641,21 @@ impl Remote { Ok(Default::default()) } - pub fn execute_get(&self, conn: Connection, request: GetRequest) -> GetProgress { + pub fn execute_get(&self, conn: impl GetStreamPair, request: GetRequest) -> GetProgress { self.execute_get_with_opts(conn, request) } - pub fn execute_get_with_opts(&self, conn: Connection, request: GetRequest) -> GetProgress { + pub fn execute_get_with_opts( + &self, + conn: impl GetStreamPair, + request: GetRequest, + ) -> GetProgress { let (tx, rx) = tokio::sync::mpsc::channel(64); let tx2 = tx.clone(); let sink = TokioMpscSenderSink(tx).with_map(GetProgressItem::Progress); let this = self.clone(); let fut = async move { - let res = this.execute_get_sink(&conn, request, sink).await.into(); + let res = this.execute_get_sink(conn, request, sink).await.into(); tx2.send(res).await.ok(); }; GetProgress { @@ -651,16 +674,19 @@ impl Remote { /// This will return the stats of the download. pub(crate) async fn execute_get_sink( &self, - conn: &Connection, + conn: impl GetStreamPair, request: GetRequest, mut progress: impl Sink, ) -> GetResult { let store = self.store(); let root = request.hash; + let conn = conn.open_stream_pair().await.map_err(|e| { + LocalFailureSnafu.into_error(anyhow::anyhow!("failed to open stream pair: {e}")) + })?; // I am cloning the connection, but it's fine because the original connection or ConnectionRef stays alive // for the duration of the operation. - let start = crate::get::fsm::start(conn.clone(), request, Default::default()); - let connected = start.next().await?; + let connected = + AtConnected::new(conn.t0, conn.recv, conn.send, request, Default::default()); trace!("Getting header"); // read the header let next_child = match connected.next().await? 
{ @@ -685,7 +711,7 @@ impl Remote { .await .map_err(|e| LocalFailureSnafu.into_error(e.into()))?, ) - .map_err(|source| BadRequestSnafu.into_error(source.into()))?; + .context(BadRequestSnafu)?; // let mut hash_seq = LazyHashSeq::new(store.blobs().clone(), root); loop { let at_start_child = match next_child { @@ -755,7 +781,6 @@ impl Remote { Err(at_closing) => break at_closing, }; let offset = at_start_child.offset(); - println!("offset {offset}"); let Some(hash) = hash_seq.get(offset as usize) else { break at_start_child.finish(); }; @@ -820,52 +845,25 @@ pub enum ExecuteError { }, } -use std::{ - collections::BTreeMap, - future::{Future, IntoFuture}, - num::NonZeroU64, - sync::Arc, -}; - -use bao_tree::{ - io::{BaoContentItem, Leaf}, - ChunkNum, ChunkRanges, -}; -use iroh::endpoint::Connection; -use tracing::{debug, trace}; - -use crate::{ - api::{self, blobs::Blobs, Store}, - get::fsm::{AtBlobHeader, AtEndBlob, BlobContentNext, ConnectedNext, EndBlobNext}, - hashseq::{HashSeq, HashSeqIter}, - protocol::{ChunkRangesSeq, GetRequest}, - store::IROH_BLOCK_SIZE, - Hash, HashAndFormat, -}; - -/// Trait to lazily get a connection -pub trait GetConnection { - fn connection(&mut self) - -> impl Future> + Send + '_; +pub trait GetStreamPair: Send + 'static { + fn open_stream_pair( + self, + ) -> impl Future>> + Send + 'static; } -/// If we already have a connection, the impl is trivial -impl GetConnection for Connection { - fn connection( - &mut self, - ) -> impl Future> + Send + '_ { - let conn = self.clone(); - async { Ok(conn) } +impl GetStreamPair for StreamPair { + async fn open_stream_pair(self) -> io::Result> { + Ok(self) } } -/// If we already have a connection, the impl is trivial -impl GetConnection for &Connection { - fn connection( - &mut self, - ) -> impl Future> + Send + '_ { - let conn = self.clone(); - async { Ok(conn) } +impl GetStreamPair for Connection { + async fn open_stream_pair( + self, + ) -> io::Result> { + let connection_id = self.stable_id() as u64; + let (send, recv) = self.open_bi().await?; + Ok(StreamPair::new(connection_id, recv, send)) } } @@ -873,12 +871,12 @@ fn get_buffer_size(size: NonZeroU64) -> usize { (size.get() / (IROH_BLOCK_SIZE.bytes() as u64) + 2).min(64) as usize } -async fn get_blob_ranges_impl( - header: AtBlobHeader, +async fn get_blob_ranges_impl( + header: AtBlobHeader, hash: Hash, store: &Store, mut progress: impl Sink, -) -> GetResult { +) -> GetResult> { let (mut content, size) = header.next().await?; let Some(size) = NonZeroU64::new(size) else { return if hash == Hash::EMPTY { @@ -915,8 +913,7 @@ async fn get_blob_ranges_impl( }; let complete = async move { handle.rx.await.map_err(|e| { - LocalFailureSnafu - .into_error(anyhow::anyhow!("error reading from import stream: {e}").into()) + LocalFailureSnafu.into_error(anyhow::anyhow!("error reading from import stream: {e}")) }) }; let (_, end) = tokio::try_join!(complete, write)?; @@ -1017,20 +1014,23 @@ impl LazyHashSeq { async fn write_push_request( request: PushRequest, - stream: &mut SendStream, + stream: &mut impl SendStream, ) -> anyhow::Result { let mut request_bytes = Vec::new(); request_bytes.push(RequestType::Push as u8); request_bytes.write_length_prefixed(&request).unwrap(); - stream.write_all(&request_bytes).await?; + stream.send_bytes(request_bytes.into()).await?; Ok(request) } -async fn write_observe_request(request: ObserveRequest, stream: &mut SendStream) -> io::Result<()> { +async fn write_observe_request( + request: ObserveRequest, + stream: &mut impl SendStream, +) -> 
io::Result<()> { let request = Request::Observe(request); let request_bytes = postcard::to_allocvec(&request) .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - stream.write_all(&request_bytes).await?; + stream.send_bytes(request_bytes.into()).await?; Ok(()) } diff --git a/src/get.rs b/src/get.rs index 049ef4855..d13092a85 100644 --- a/src/get.rs +++ b/src/get.rs @@ -17,7 +17,6 @@ //! //! [iroh]: https://docs.rs/iroh use std::{ - error::Error, fmt::{self, Debug}, time::{Duration, Instant}, }; @@ -25,22 +24,44 @@ use std::{ use anyhow::Result; use bao_tree::{io::fsm::BaoContentItem, ChunkNum}; use fsm::RequestCounters; -use iroh::endpoint::{self, RecvStream, SendStream}; -use iroh_io::TokioStreamReader; use n0_snafu::SpanTrace; use nested_enum_utils::common_fields; use serde::{Deserialize, Serialize}; use snafu::{Backtrace, IntoError, ResultExt, Snafu}; use tracing::{debug, error}; -use crate::{protocol::ChunkRangesSeq, store::IROH_BLOCK_SIZE, Hash}; +use crate::{ + protocol::ChunkRangesSeq, + store::IROH_BLOCK_SIZE, + util::{RecvStream, SendStream}, + Hash, +}; mod error; pub mod request; -pub(crate) use error::{BadRequestSnafu, LocalFailureSnafu}; +pub(crate) use error::get_error; pub use error::{GetError, GetResult}; -type WrappedRecvStream = TokioStreamReader; +type DefaultReader = iroh::endpoint::RecvStream; +type DefaultWriter = iroh::endpoint::SendStream; + +pub struct StreamPair { + pub connection_id: u64, + pub t0: Instant, + pub recv: R, + pub send: W, +} + +impl StreamPair { + pub fn new(connection_id: u64, recv: R, send: W) -> Self { + Self { + t0: Instant::now(), + recv, + send, + connection_id, + } + } +} /// Stats about the transfer. #[derive( @@ -96,14 +117,15 @@ pub mod fsm { }; use derive_more::From; use iroh::endpoint::Connection; - use iroh_io::{AsyncSliceWriter, AsyncStreamReader, TokioStreamReader}; + use iroh_io::AsyncSliceWriter; use super::*; use crate::{ - get::error::BadRequestSnafu, + get::get_error::BadRequestSnafu, protocol::{ GetManyRequest, GetRequest, NonEmptyRequestRangeSpecIter, Request, MAX_MESSAGE_SIZE, }, + util::{RecvStream, RecvStreamAsyncStreamReader, SendStream}, }; self_cell::self_cell! 
{ @@ -130,16 +152,20 @@ pub mod fsm { counters: RequestCounters, ) -> std::result::Result, GetError> { let start = Instant::now(); - let (mut writer, reader) = connection.open_bi().await?; + let (mut writer, reader) = connection + .open_bi() + .await + .map_err(|e| OpenSnafu.into_error(e.into()))?; let request = Request::GetMany(request); let request_bytes = postcard::to_stdvec(&request) .map_err(|source| BadRequestSnafu.into_error(source.into()))?; - writer.write_all(&request_bytes).await?; - writer.finish()?; + writer + .send_bytes(request_bytes.into()) + .await + .context(connected_next_error::WriteSnafu)?; let Request::GetMany(request) = request else { unreachable!(); }; - let reader = TokioStreamReader::new(reader); let mut ranges_iter = RangesIter::new(request.ranges.clone()); let first_item = ranges_iter.next(); let misc = Box::new(Misc { @@ -214,10 +240,13 @@ pub mod fsm { } /// Initiate a new bidi stream to use for the get response - pub async fn next(self) -> Result { + pub async fn next(self) -> Result { let start = Instant::now(); - let (writer, reader) = self.connection.open_bi().await?; - let reader = TokioStreamReader::new(reader); + let (writer, reader) = self + .connection + .open_bi() + .await + .map_err(|e| OpenSnafu.into_error(e.into()))?; Ok(AtConnected { start, reader, @@ -228,25 +257,38 @@ pub mod fsm { } } + /// Error that you can get from [`AtConnected::next`] + #[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: SpanTrace, + })] + #[allow(missing_docs)] + #[derive(Debug, Snafu)] + #[non_exhaustive] + pub enum InitialNextError { + Open { source: io::Error }, + } + /// State of the get response machine after the handshake has been sent #[derive(Debug)] - pub struct AtConnected { + pub struct AtConnected { start: Instant, - reader: WrappedRecvStream, - writer: SendStream, + reader: R, + writer: W, request: GetRequest, counters: RequestCounters, } /// Possible next states after the handshake has been sent #[derive(Debug, From)] - pub enum ConnectedNext { + pub enum ConnectedNext { /// First response is either a collection or a single blob - StartRoot(AtStartRoot), + StartRoot(AtStartRoot), /// First response is a child - StartChild(AtStartChild), + StartChild(AtStartChild), /// Request is empty - Closing(AtClosing), + Closing(AtClosing), } /// Error that you can get from [`AtConnected::next`] @@ -257,6 +299,7 @@ pub mod fsm { })] #[allow(missing_docs)] #[derive(Debug, Snafu)] + #[snafu(module)] #[non_exhaustive] pub enum ConnectedNextError { /// Error when serializing the request @@ -267,23 +310,33 @@ pub mod fsm { RequestTooBig {}, /// Error when writing the request to the [`SendStream`]. #[snafu(display("write: {source}"))] - Write { source: quinn::WriteError }, - /// Quic connection is closed. - #[snafu(display("closed"))] - Closed { source: quinn::ClosedStream }, - /// A generic io error - #[snafu(transparent)] - Io { source: io::Error }, + Write { source: io::Error }, } - impl AtConnected { + impl AtConnected { + pub fn new( + start: Instant, + reader: R, + writer: W, + request: GetRequest, + counters: RequestCounters, + ) -> Self { + Self { + start, + reader, + writer, + request, + counters, + } + } + /// Send the request and move to the next state /// /// The next state will be either `StartRoot` or `StartChild` depending on whether /// the request requests part of the collection or not. /// /// If the request is empty, this can also move directly to `Finished`. 
- pub async fn next(self) -> Result { + pub async fn next(self) -> Result, ConnectedNextError> { let Self { start, reader, @@ -295,23 +348,32 @@ pub mod fsm { counters.other_bytes_written += { debug!("sending request"); let wrapped = Request::Get(request); - let request_bytes = postcard::to_stdvec(&wrapped).context(PostcardSerSnafu)?; + let request_bytes = postcard::to_stdvec(&wrapped) + .context(connected_next_error::PostcardSerSnafu)?; let Request::Get(x) = wrapped else { unreachable!(); }; request = x; if request_bytes.len() > MAX_MESSAGE_SIZE { - return Err(RequestTooBigSnafu.build()); + return Err(connected_next_error::RequestTooBigSnafu.build()); } // write the request itself - writer.write_all(&request_bytes).await.context(WriteSnafu)?; - request_bytes.len() as u64 + let len = request_bytes.len() as u64; + writer + .send_bytes(request_bytes.into()) + .await + .context(connected_next_error::WriteSnafu)?; + writer + .sync() + .await + .context(connected_next_error::WriteSnafu)?; + len }; // 2. Finish writing before expecting a response - writer.finish().context(ClosedSnafu)?; + drop(writer); let hash = request.hash; let ranges_iter = RangesIter::new(request.ranges); @@ -348,23 +410,23 @@ pub mod fsm { /// State of the get response when we start reading a collection #[derive(Debug)] - pub struct AtStartRoot { + pub struct AtStartRoot { ranges: ChunkRanges, - reader: TokioStreamReader, + reader: R, misc: Box, hash: Hash, } /// State of the get response when we start reading a child #[derive(Debug)] - pub struct AtStartChild { + pub struct AtStartChild { ranges: ChunkRanges, - reader: TokioStreamReader, + reader: R, misc: Box, offset: u64, } - impl AtStartChild { + impl AtStartChild { /// The offset of the child we are currently reading /// /// This must be used to determine the hash needed to call next. @@ -382,7 +444,7 @@ pub mod fsm { /// Go into the next state, reading the header /// /// This requires passing in the hash of the child for validation - pub fn next(self, hash: Hash) -> AtBlobHeader { + pub fn next(self, hash: Hash) -> AtBlobHeader { AtBlobHeader { reader: self.reader, ranges: self.ranges, @@ -396,12 +458,12 @@ pub mod fsm { /// This is used if you know that there are no more children from having /// read the collection, or when you want to stop reading the response /// early. 
- pub fn finish(self) -> AtClosing { + pub fn finish(self) -> AtClosing { AtClosing::new(self.misc, self.reader, false) } } - impl AtStartRoot { + impl AtStartRoot { /// The ranges we have requested for the child pub fn ranges(&self) -> &ChunkRanges { &self.ranges @@ -415,7 +477,7 @@ pub mod fsm { /// Go into the next state, reading the header /// /// For the collection we already know the hash, since it was part of the request - pub fn next(self) -> AtBlobHeader { + pub fn next(self) -> AtBlobHeader { AtBlobHeader { reader: self.reader, ranges: self.ranges, @@ -425,16 +487,16 @@ pub mod fsm { } /// Finish the get response without reading further - pub fn finish(self) -> AtClosing { + pub fn finish(self) -> AtClosing { AtClosing::new(self.misc, self.reader, false) } } /// State before reading a size header #[derive(Debug)] - pub struct AtBlobHeader { + pub struct AtBlobHeader { ranges: ChunkRanges, - reader: TokioStreamReader, + reader: R, misc: Box, hash: Hash, } @@ -447,18 +509,16 @@ pub mod fsm { })] #[non_exhaustive] #[derive(Debug, Snafu)] + #[snafu(module)] pub enum AtBlobHeaderNextError { /// Eof when reading the size header /// /// This indicates that the provider does not have the requested data. #[snafu(display("not found"))] NotFound {}, - /// Quinn read error when reading the size header - #[snafu(display("read: {source}"))] - EndpointRead { source: endpoint::ReadError }, /// Generic io error #[snafu(display("io: {source}"))] - Io { source: io::Error }, + Read { source: io::Error }, } impl From for io::Error { @@ -467,25 +527,20 @@ pub mod fsm { AtBlobHeaderNextError::NotFound { .. } => { io::Error::new(io::ErrorKind::UnexpectedEof, cause) } - AtBlobHeaderNextError::EndpointRead { source, .. } => source.into(), - AtBlobHeaderNextError::Io { source, .. } => source, + AtBlobHeaderNextError::Read { source, .. } => source, } } } - impl AtBlobHeader { + impl AtBlobHeader { /// Read the size header, returning it and going into the `Content` state. - pub async fn next(mut self) -> Result<(AtBlobContent, u64), AtBlobHeaderNextError> { - let size = self.reader.read::<8>().await.map_err(|cause| { + pub async fn next(mut self) -> Result<(AtBlobContent, u64), AtBlobHeaderNextError> { + let mut size = [0; 8]; + self.reader.recv_exact(&mut size).await.map_err(|cause| { if cause.kind() == io::ErrorKind::UnexpectedEof { - NotFoundSnafu.build() - } else if let Some(e) = cause - .get_ref() - .and_then(|x| x.downcast_ref::()) - { - EndpointReadSnafu.into_error(e.clone()) + at_blob_header_next_error::NotFoundSnafu.build() } else { - IoSnafu.into_error(cause) + at_blob_header_next_error::ReadSnafu.into_error(cause) } })?; self.misc.other_bytes_read += 8; @@ -494,7 +549,7 @@ pub mod fsm { self.hash.into(), self.ranges, BaoTree::new(size, IROH_BLOCK_SIZE), - self.reader, + RecvStreamAsyncStreamReader::new(self.reader), ); Ok(( AtBlobContent { @@ -506,7 +561,7 @@ pub mod fsm { } /// Drain the response and throw away the result - pub async fn drain(self) -> result::Result { + pub async fn drain(self) -> result::Result, DecodeError> { let (content, _size) = self.next().await?; content.drain().await } @@ -517,7 +572,7 @@ pub mod fsm { /// concatenate the ranges that were requested. 
pub async fn concatenate_into_vec( self, - ) -> result::Result<(AtEndBlob, Vec), DecodeError> { + ) -> result::Result<(AtEndBlob, Vec), DecodeError> { let (content, _size) = self.next().await?; content.concatenate_into_vec().await } @@ -526,7 +581,7 @@ pub mod fsm { pub async fn write_all( self, data: D, - ) -> result::Result { + ) -> result::Result, DecodeError> { let (content, _size) = self.next().await?; let res = content.write_all(data).await?; Ok(res) @@ -540,7 +595,7 @@ pub mod fsm { self, outboard: Option, data: D, - ) -> result::Result + ) -> result::Result, DecodeError> where D: AsyncSliceWriter, O: OutboardMut, @@ -568,8 +623,8 @@ pub mod fsm { /// State while we are reading content #[derive(Debug)] - pub struct AtBlobContent { - stream: ResponseDecoder, + pub struct AtBlobContent { + stream: ResponseDecoder>, misc: Box, } @@ -603,6 +658,7 @@ pub mod fsm { })] #[non_exhaustive] #[derive(Debug, Snafu)] + #[snafu(module)] pub enum DecodeError { /// A chunk was not found or invalid, so the provider stopped sending data #[snafu(display("not found"))] @@ -621,24 +677,25 @@ pub mod fsm { LeafHashMismatch { num: ChunkNum }, /// Error when reading from the stream #[snafu(display("read: {source}"))] - Read { source: endpoint::ReadError }, + Read { source: io::Error }, /// A generic io error #[snafu(display("io: {source}"))] - DecodeIo { source: io::Error }, + Write { source: io::Error }, } impl DecodeError { pub(crate) fn leaf_hash_mismatch(num: ChunkNum) -> Self { - LeafHashMismatchSnafu { num }.build() + decode_error::LeafHashMismatchSnafu { num }.build() } } impl From for DecodeError { fn from(cause: AtBlobHeaderNextError) -> Self { match cause { - AtBlobHeaderNextError::NotFound { .. } => ChunkNotFoundSnafu.build(), - AtBlobHeaderNextError::EndpointRead { source, .. } => ReadSnafu.into_error(source), - AtBlobHeaderNextError::Io { source, .. } => DecodeIoSnafu.into_error(source), + AtBlobHeaderNextError::NotFound { .. } => decode_error::ChunkNotFoundSnafu.build(), + AtBlobHeaderNextError::Read { source, .. } => { + decode_error::ReadSnafu.into_error(source) + } } } } @@ -652,59 +709,50 @@ pub mod fsm { DecodeError::LeafNotFound { .. } => { io::Error::new(io::ErrorKind::UnexpectedEof, cause) } - DecodeError::Read { source, .. } => source.into(), - DecodeError::DecodeIo { source, .. } => source, + DecodeError::Read { source, .. } => source, + DecodeError::Write { source, .. 
} => source, _ => io::Error::other(cause), } } } - impl From for DecodeError { - fn from(value: io::Error) -> Self { - DecodeIoSnafu.into_error(value) - } - } - impl From for DecodeError { fn from(value: bao_tree::io::DecodeError) -> Self { match value { - bao_tree::io::DecodeError::ParentNotFound(x) => { - ParentNotFoundSnafu { node: x }.build() + bao_tree::io::DecodeError::ParentNotFound(node) => { + decode_error::ParentNotFoundSnafu { node }.build() + } + bao_tree::io::DecodeError::LeafNotFound(num) => { + decode_error::LeafNotFoundSnafu { num }.build() } - bao_tree::io::DecodeError::LeafNotFound(x) => LeafNotFoundSnafu { num: x }.build(), bao_tree::io::DecodeError::ParentHashMismatch(node) => { - ParentHashMismatchSnafu { node }.build() + decode_error::ParentHashMismatchSnafu { node }.build() } - bao_tree::io::DecodeError::LeafHashMismatch(chunk) => { - LeafHashMismatchSnafu { num: chunk }.build() - } - bao_tree::io::DecodeError::Io(cause) => { - if let Some(inner) = cause.get_ref() { - if let Some(e) = inner.downcast_ref::() { - ReadSnafu.into_error(e.clone()) - } else { - DecodeIoSnafu.into_error(cause) - } - } else { - DecodeIoSnafu.into_error(cause) - } + bao_tree::io::DecodeError::LeafHashMismatch(num) => { + decode_error::LeafHashMismatchSnafu { num }.build() } + bao_tree::io::DecodeError::Io(cause) => decode_error::ReadSnafu.into_error(cause), } } } /// The next state after reading a content item #[derive(Debug, From)] - pub enum BlobContentNext { + pub enum BlobContentNext { /// We expect more content - More((AtBlobContent, result::Result)), + More( + ( + AtBlobContent, + result::Result, + ), + ), /// We are done with this blob - Done(AtEndBlob), + Done(AtEndBlob), } - impl AtBlobContent { + impl AtBlobContent { /// Read the next item, either content, an error, or the end of the blob - pub async fn next(self) -> BlobContentNext { + pub async fn next(self) -> BlobContentNext { match self.stream.next().await { ResponseDecoderNext::More((stream, res)) => { let mut next = Self { stream, ..self }; @@ -721,7 +769,7 @@ pub mod fsm { BlobContentNext::More((next, res)) } ResponseDecoderNext::Done(stream) => BlobContentNext::Done(AtEndBlob { - stream, + stream: stream.into_inner(), misc: self.misc, }), } @@ -751,7 +799,7 @@ pub mod fsm { } /// Drain the response and throw away the result - pub async fn drain(self) -> result::Result { + pub async fn drain(self) -> result::Result, DecodeError> { let mut content = self; loop { match content.next().await { @@ -769,7 +817,7 @@ pub mod fsm { /// Concatenate the entire response into a vec pub async fn concatenate_into_vec( self, - ) -> result::Result<(AtEndBlob, Vec), DecodeError> { + ) -> result::Result<(AtEndBlob, Vec), DecodeError> { let mut res = Vec::with_capacity(1024); let mut curr = self; let done = loop { @@ -797,7 +845,7 @@ pub mod fsm { self, mut outboard: Option, mut data: D, - ) -> result::Result + ) -> result::Result, DecodeError> where D: AsyncSliceWriter, O: OutboardMut, @@ -810,11 +858,16 @@ pub mod fsm { match item? 
{ BaoContentItem::Parent(parent) => { if let Some(outboard) = outboard.as_mut() { - outboard.save(parent.node, &parent.pair).await?; + outboard + .save(parent.node, &parent.pair) + .await + .map_err(|e| decode_error::WriteSnafu.into_error(e))?; } } BaoContentItem::Leaf(leaf) => { - data.write_bytes_at(leaf.offset, leaf.data).await?; + data.write_bytes_at(leaf.offset, leaf.data) + .await + .map_err(|e| decode_error::WriteSnafu.into_error(e))?; } } } @@ -826,7 +879,7 @@ pub mod fsm { } /// Write the entire blob to a slice writer. - pub async fn write_all(self, mut data: D) -> result::Result + pub async fn write_all(self, mut data: D) -> result::Result, DecodeError> where D: AsyncSliceWriter, { @@ -838,7 +891,9 @@ pub mod fsm { match item? { BaoContentItem::Parent(_) => {} BaoContentItem::Leaf(leaf) => { - data.write_bytes_at(leaf.offset, leaf.data).await?; + data.write_bytes_at(leaf.offset, leaf.data) + .await + .map_err(|e| decode_error::WriteSnafu.into_error(e))?; } } } @@ -850,30 +905,30 @@ pub mod fsm { } /// Immediately finish the get response without reading further - pub fn finish(self) -> AtClosing { - AtClosing::new(self.misc, self.stream.finish(), false) + pub fn finish(self) -> AtClosing { + AtClosing::new(self.misc, self.stream.finish().into_inner(), false) } } /// State after we have read all the content for a blob #[derive(Debug)] - pub struct AtEndBlob { - stream: WrappedRecvStream, + pub struct AtEndBlob { + stream: R, misc: Box, } /// The next state after the end of a blob #[derive(Debug, From)] - pub enum EndBlobNext { + pub enum EndBlobNext { /// Response is expected to have more children - MoreChildren(AtStartChild), + MoreChildren(AtStartChild), /// No more children expected - Closing(AtClosing), + Closing(AtClosing), } - impl AtEndBlob { + impl AtEndBlob { /// Read the next child, or finish - pub fn next(mut self) -> EndBlobNext { + pub fn next(mut self) -> EndBlobNext { if let Some((offset, ranges)) = self.misc.ranges_iter.next() { AtStartChild { reader: self.stream, @@ -890,14 +945,14 @@ pub mod fsm { /// State when finishing the get response #[derive(Debug)] - pub struct AtClosing { + pub struct AtClosing { misc: Box, - reader: WrappedRecvStream, + reader: R, check_extra_data: bool, } - impl AtClosing { - fn new(misc: Box, reader: WrappedRecvStream, check_extra_data: bool) -> Self { + impl AtClosing { + fn new(misc: Box, reader: R, check_extra_data: bool) -> Self { Self { misc, reader, @@ -906,17 +961,14 @@ pub mod fsm { } /// Finish the get response, returning statistics - pub async fn next(self) -> result::Result { + pub async fn next(self) -> result::Result { // Shut down the stream - let reader = self.reader; - let mut reader = reader.into_inner(); + let mut reader = self.reader; if self.check_extra_data { - if let Some(chunk) = reader.read_chunk(8, false).await? 
{ - reader.stop(0u8.into()).ok(); - error!("Received unexpected data from the provider: {chunk:?}"); + let rest = reader.recv_bytes(1).await?; + if !rest.is_empty() { + error!("Unexpected extra data at the end of the stream"); } - } else { - reader.stop(0u8.into()).ok(); } Ok(Stats { counters: self.misc.counters, @@ -925,6 +977,21 @@ pub mod fsm { } } + /// Error that you can get from [`AtBlobHeader::next`] + #[common_fields({ + backtrace: Option, + #[snafu(implicit)] + span_trace: SpanTrace, + })] + #[non_exhaustive] + #[derive(Debug, Snafu)] + #[snafu(module)] + pub enum AtClosingNextError { + /// Generic io error + #[snafu(transparent)] + Read { source: io::Error }, + } + #[derive(Debug, Serialize, Deserialize, Default, Clone, Copy, PartialEq, Eq)] pub struct RequestCounters { /// payload bytes written @@ -950,71 +1017,3 @@ pub mod fsm { ranges_iter: RangesIter, } } - -/// Error when processing a response -#[common_fields({ - backtrace: Option, - #[snafu(implicit)] - span_trace: SpanTrace, -})] -#[allow(missing_docs)] -#[non_exhaustive] -#[derive(Debug, Snafu)] -pub enum GetResponseError { - /// Error when opening a stream - #[snafu(display("connection: {source}"))] - Connection { source: endpoint::ConnectionError }, - /// Error when writing the handshake or request to the stream - #[snafu(display("write: {source}"))] - Write { source: endpoint::WriteError }, - /// Error when reading from the stream - #[snafu(display("read: {source}"))] - Read { source: endpoint::ReadError }, - /// Error when decoding, e.g. hash mismatch - #[snafu(display("decode: {source}"))] - Decode { source: bao_tree::io::DecodeError }, - /// A generic error - #[snafu(display("generic: {source}"))] - Generic { source: anyhow::Error }, -} - -impl From for GetResponseError { - fn from(cause: postcard::Error) -> Self { - GenericSnafu.into_error(cause.into()) - } -} - -impl From for GetResponseError { - fn from(cause: bao_tree::io::DecodeError) -> Self { - match cause { - bao_tree::io::DecodeError::Io(cause) => { - // try to downcast to specific quinn errors - if let Some(source) = cause.source() { - if let Some(error) = source.downcast_ref::() { - return ConnectionSnafu.into_error(error.clone()); - } - if let Some(error) = source.downcast_ref::() { - return ReadSnafu.into_error(error.clone()); - } - if let Some(error) = source.downcast_ref::() { - return WriteSnafu.into_error(error.clone()); - } - } - GenericSnafu.into_error(cause.into()) - } - _ => DecodeSnafu.into_error(cause), - } - } -} - -impl From for GetResponseError { - fn from(cause: anyhow::Error) -> Self { - GenericSnafu.into_error(cause) - } -} - -impl From for std::io::Error { - fn from(cause: GetResponseError) -> Self { - Self::other(cause) - } -} diff --git a/src/get/error.rs b/src/get/error.rs index 1c3ea9465..5cc44e35b 100644 --- a/src/get/error.rs +++ b/src/get/error.rs @@ -1,102 +1,15 @@ //! 
Error returned from get operations use std::io; -use iroh::endpoint::{self, ClosedStream}; +use iroh::endpoint::{ConnectionError, ReadError, VarInt, WriteError}; use n0_snafu::SpanTrace; use nested_enum_utils::common_fields; -use quinn::{ConnectionError, ReadError, WriteError}; -use snafu::{Backtrace, IntoError, Snafu}; +use snafu::{Backtrace, Snafu}; -use crate::{ - api::ExportBaoError, - get::fsm::{AtBlobHeaderNextError, ConnectedNextError, DecodeError}, +use crate::get::fsm::{ + AtBlobHeaderNextError, AtClosingNextError, ConnectedNextError, DecodeError, InitialNextError, }; -#[derive(Debug, Snafu)] -pub enum NotFoundCases { - #[snafu(transparent)] - AtBlobHeaderNext { source: AtBlobHeaderNextError }, - #[snafu(transparent)] - Decode { source: DecodeError }, -} - -#[derive(Debug, Snafu)] -pub enum NoncompliantNodeCases { - #[snafu(transparent)] - Connection { source: ConnectionError }, - #[snafu(transparent)] - Decode { source: DecodeError }, -} - -#[derive(Debug, Snafu)] -pub enum RemoteResetCases { - #[snafu(transparent)] - Read { source: ReadError }, - #[snafu(transparent)] - Write { source: WriteError }, - #[snafu(transparent)] - Connection { source: ConnectionError }, -} - -#[derive(Debug, Snafu)] -pub enum BadRequestCases { - #[snafu(transparent)] - Anyhow { source: anyhow::Error }, - #[snafu(transparent)] - Postcard { source: postcard::Error }, - #[snafu(transparent)] - ConnectedNext { source: ConnectedNextError }, -} - -#[derive(Debug, Snafu)] -pub enum LocalFailureCases { - #[snafu(transparent)] - Io { - source: io::Error, - }, - #[snafu(transparent)] - Anyhow { - source: anyhow::Error, - }, - #[snafu(transparent)] - IrpcSend { - source: irpc::channel::SendError, - }, - #[snafu(transparent)] - Irpc { - source: irpc::Error, - }, - #[snafu(transparent)] - ExportBao { - source: ExportBaoError, - }, - TokioSend {}, -} - -impl From> for LocalFailureCases { - fn from(_: tokio::sync::mpsc::error::SendError) -> Self { - LocalFailureCases::TokioSend {} - } -} - -#[derive(Debug, Snafu)] -pub enum IoCases { - #[snafu(transparent)] - Io { source: io::Error }, - #[snafu(transparent)] - ConnectionError { source: endpoint::ConnectionError }, - #[snafu(transparent)] - ReadError { source: endpoint::ReadError }, - #[snafu(transparent)] - WriteError { source: endpoint::WriteError }, - #[snafu(transparent)] - ClosedStream { source: endpoint::ClosedStream }, - #[snafu(transparent)] - ConnectedNextError { source: ConnectedNextError }, - #[snafu(transparent)] - AtBlobHeaderNextError { source: AtBlobHeaderNextError }, -} - /// Failures for a get operation #[common_fields({ backtrace: Option, @@ -105,210 +18,112 @@ pub enum IoCases { })] #[derive(Debug, Snafu)] #[snafu(visibility(pub(crate)))] +#[snafu(module)] pub enum GetError { - /// Hash not found, or a requested chunk for the hash not found. - #[snafu(display("Data for hash not found"))] - NotFound { - #[snafu(source(from(NotFoundCases, Box::new)))] - source: Box, + #[snafu(transparent)] + InitialNext { + source: InitialNextError, }, - /// Remote has reset the connection. - #[snafu(display("Remote has reset the connection"))] - RemoteReset { - #[snafu(source(from(RemoteResetCases, Box::new)))] - source: Box, + #[snafu(transparent)] + ConnectedNext { + source: ConnectedNextError, }, - /// Remote behaved in a non-compliant way. 
-    #[snafu(display("Remote behaved in a non-compliant way"))]
-    NoncompliantNode {
-        #[snafu(source(from(NoncompliantNodeCases, Box::new)))]
-        source: Box<NoncompliantNodeCases>,
+    #[snafu(transparent)]
+    AtBlobHeaderNext {
+        source: AtBlobHeaderNextError,
     },
-
-    /// Network or IO operation failed.
-    #[snafu(display("A network or IO operation failed"))]
-    Io {
-        #[snafu(source(from(IoCases, Box::new)))]
-        source: Box<IoCases>,
+    #[snafu(transparent)]
+    Decode {
+        source: DecodeError,
     },
-    /// Our download request is invalid.
-    #[snafu(display("Our download request is invalid"))]
-    BadRequest {
-        #[snafu(source(from(BadRequestCases, Box::new)))]
-        source: Box<BadRequestCases>,
+    #[snafu(transparent)]
+    IrpcSend {
+        source: irpc::channel::SendError,
+    },
+    #[snafu(transparent)]
+    AtClosingNext {
+        source: AtClosingNextError,
     },
-    /// Operation failed on the local node.
-    #[snafu(display("Operation failed on the local node"))]
     LocalFailure {
-        #[snafu(source(from(LocalFailureCases, Box::new)))]
-        source: Box<LocalFailureCases>,
+        source: anyhow::Error,
+    },
+    BadRequest {
+        source: anyhow::Error,
     },
 }
 
-pub type GetResult<T> = std::result::Result<T, GetError>;
-
-impl From<irpc::channel::SendError> for GetError {
-    fn from(value: irpc::channel::SendError) -> Self {
-        LocalFailureSnafu.into_error(value.into())
-    }
-}
-
-impl<T> From<tokio::sync::mpsc::error::SendError<T>> for GetError {
-    fn from(value: tokio::sync::mpsc::error::SendError<T>) -> Self {
-        LocalFailureSnafu.into_error(value.into())
-    }
-}
-
-impl From<endpoint::ConnectionError> for GetError {
-    fn from(value: endpoint::ConnectionError) -> Self {
-        // explicit match just to be sure we are taking everything into account
-        use endpoint::ConnectionError;
-        match value {
-            e @ ConnectionError::VersionMismatch => {
-                // > The peer doesn't implement any supported version
-                // unsupported version is likely a long time error, so this peer is not usable
-                NoncompliantNodeSnafu.into_error(e.into())
-            }
-            e @ ConnectionError::TransportError(_) => {
-                // > The peer violated the QUIC specification as understood by this implementation
-                // bad peer we don't want to keep around
-                NoncompliantNodeSnafu.into_error(e.into())
-            }
-            e @ ConnectionError::ConnectionClosed(_) => {
-                // > The peer's QUIC stack aborted the connection automatically
-                // peer might be disconnecting or otherwise unavailable, drop it
-                IoSnafu.into_error(e.into())
-            }
-            e @ ConnectionError::ApplicationClosed(_) => {
-                // > The peer closed the connection
-                // peer might be disconnecting or otherwise unavailable, drop it
-                IoSnafu.into_error(e.into())
-            }
-            e @ ConnectionError::Reset => {
-                // > The peer is unable to continue processing this connection, usually due to having restarted
-                RemoteResetSnafu.into_error(e.into())
-            }
-            e @ ConnectionError::TimedOut => {
-                // > Communication with the peer has lapsed for longer than the negotiated idle timeout
-                IoSnafu.into_error(e.into())
-            }
-            e @ ConnectionError::LocallyClosed => {
-                // > The local application closed the connection
-                // TODO(@divma): don't see how this is reachable but let's just not use the peer
-                IoSnafu.into_error(e.into())
-            }
-            e @ ConnectionError::CidsExhausted => {
-                // > The connection could not be created because not enough of the CID space
-                // > is available
-                IoSnafu.into_error(e.into())
-            }
-        }
-    }
-}
-
-impl From<endpoint::ReadError> for GetError {
-    fn from(value: endpoint::ReadError) -> Self {
-        use endpoint::ReadError;
-        match value {
-            e @ ReadError::Reset(_) => RemoteResetSnafu.into_error(e.into()),
-            ReadError::ConnectionLost(conn_error) => conn_error.into(),
-            ReadError::ClosedStream
-            | ReadError::IllegalOrderedRead
-            | ReadError::ZeroRttRejected => {
-                // all these errors indicate the peer is not usable at this moment
-                IoSnafu.into_error(value.into())
-            }
-        }
-    }
-}
+impl GetError {
+    pub fn iroh_error_code(&self) -> Option<VarInt> {
+        if let Some(ReadError::Reset(code)) = self
+            .remote_read()
+            .and_then(|source| source.get_ref())
+            .and_then(|e| e.downcast_ref::<ReadError>())
+        {
+            Some(*code)
+        } else if let Some(WriteError::Stopped(code)) = self
+            .remote_write()
+            .and_then(|source| source.get_ref())
+            .and_then(|e| e.downcast_ref::<WriteError>())
+        {
+            Some(*code)
+        } else if let Some(ConnectionError::ApplicationClosed(ac)) = self
+            .open()
+            .and_then(|source| source.get_ref())
+            .and_then(|e| e.downcast_ref::<ConnectionError>())
+        {
+            Some(ac.error_code)
+        } else {
+            None
+        }
+    }
 
-impl From<ClosedStream> for GetError {
-    fn from(value: ClosedStream) -> Self {
-        IoSnafu.into_error(value.into())
-    }
-}
-
-impl From<quinn::WriteError> for GetError {
-    fn from(value: quinn::WriteError) -> Self {
-        use quinn::WriteError;
-        match value {
-            e @ WriteError::Stopped(_) => RemoteResetSnafu.into_error(e.into()),
-            WriteError::ConnectionLost(conn_error) => conn_error.into(),
-            WriteError::ClosedStream | WriteError::ZeroRttRejected => {
-                // all these errors indicate the peer is not usable at this moment
-                IoSnafu.into_error(value.into())
-            }
-        }
-    }
-}
+    pub fn remote_write(&self) -> Option<&io::Error> {
+        match self {
+            Self::ConnectedNext {
+                source: ConnectedNextError::Write { source, .. },
+                ..
+            } => Some(source),
+            _ => None,
+        }
+    }
 
-impl From<crate::get::fsm::ConnectedNextError> for GetError {
-    fn from(value: crate::get::fsm::ConnectedNextError) -> Self {
-        use crate::get::fsm::ConnectedNextError::*;
-        match value {
-            e @ PostcardSer { .. } => {
-                // serialization errors indicate something wrong with the request itself
-                BadRequestSnafu.into_error(e.into())
-            }
-            e @ RequestTooBig { .. } => {
-                // request will never be sent, drop it
-                BadRequestSnafu.into_error(e.into())
-            }
-            Write { source, .. } => source.into(),
-            Closed { source, .. } => source.into(),
-            e @ Io { .. } => {
-                // io errors are likely recoverable
-                IoSnafu.into_error(e.into())
-            }
-        }
-    }
-}
+    pub fn open(&self) -> Option<&io::Error> {
+        match self {
+            Self::InitialNext {
+                source: InitialNextError::Open { source, .. },
+                ..
+            } => Some(source),
+            _ => None,
+        }
+    }
 
-impl From<crate::get::fsm::AtBlobHeaderNextError> for GetError {
-    fn from(value: crate::get::fsm::AtBlobHeaderNextError) -> Self {
-        use crate::get::fsm::AtBlobHeaderNextError::*;
-        match value {
-            e @ NotFound { .. } => {
-                // > This indicates that the provider does not have the requested data.
-                // peer might have the data later, simply retry it
-                NotFoundSnafu.into_error(e.into())
-            }
-            EndpointRead { source, .. } => source.into(),
-            e @ Io { .. } => {
-                // io errors are likely recoverable
-                IoSnafu.into_error(e.into())
-            }
-        }
-    }
-}
+    pub fn remote_read(&self) -> Option<&io::Error> {
+        match self {
+            Self::AtBlobHeaderNext {
+                source: AtBlobHeaderNextError::Read { source, .. },
+                ..
+            } => Some(source),
+            Self::Decode {
+                source: DecodeError::Read { source, .. },
+                ..
+            } => Some(source),
+            Self::AtClosingNext {
+                source: AtClosingNextError::Read { source, .. },
+                ..
+            } => Some(source),
+            _ => None,
+        }
+    }
 
-impl From<crate::get::fsm::DecodeError> for GetError {
-    fn from(value: crate::get::fsm::DecodeError) -> Self {
-        use crate::get::fsm::DecodeError::*;
-        match value {
-            e @ ChunkNotFound { .. } => NotFoundSnafu.into_error(e.into()),
-            e @ ParentNotFound { .. } => NotFoundSnafu.into_error(e.into()),
-            e @ LeafNotFound { .. } => NotFoundSnafu.into_error(e.into()),
-            e @ ParentHashMismatch { .. } => {
-                // TODO(@divma): did the peer sent wrong data? is it corrupted? did we sent a wrong
-                // request?
-                NoncompliantNodeSnafu.into_error(e.into())
-            }
-            e @ LeafHashMismatch { .. } => {
-                // TODO(@divma): did the peer sent wrong data? is it corrupted? did we sent a wrong
-                // request?
-                NoncompliantNodeSnafu.into_error(e.into())
-            }
-            Read { source, .. } => source.into(),
-            DecodeIo { source, .. } => source.into(),
-        }
-    }
-}
+    pub fn local_write(&self) -> Option<&io::Error> {
+        match self {
+            Self::Decode {
+                source: DecodeError::Write { source, .. },
+                ..
+            } => Some(source),
+            _ => None,
+        }
+    }
+}
 
-impl From<std::io::Error> for GetError {
-    fn from(value: std::io::Error) -> Self {
-        // generally consider io errors recoverable
-        // we might want to revisit this at some point
-        IoSnafu.into_error(value.into())
-    }
-}
+pub type GetResult<T> = std::result::Result<T, GetError>;
diff --git a/src/get/request.rs b/src/get/request.rs
index 98563057e..c1dc034d3 100644
--- a/src/get/request.rs
+++ b/src/get/request.rs
@@ -25,7 +25,7 @@ use tokio::sync::mpsc;
 
 use super::{fsm, GetError, GetResult, Stats};
 use crate::{
-    get::error::{BadRequestSnafu, LocalFailureSnafu},
+    get::get_error::{BadRequestSnafu, LocalFailureSnafu},
     hashseq::HashSeq,
     protocol::{ChunkRangesExt, ChunkRangesSeq, GetRequest},
     Hash, HashAndFormat,
@@ -58,7 +58,7 @@ impl GetBlobResult {
         let mut parts = Vec::new();
         let stats = loop {
             let Some(item) = self.next().await else {
-                return Err(LocalFailureSnafu.into_error(anyhow::anyhow!("unexpected end").into()));
+                return Err(LocalFailureSnafu.into_error(anyhow::anyhow!("unexpected end")));
             };
             match item {
                 GetBlobItem::Item(item) => {
@@ -238,11 +238,11 @@ pub async fn get_hash_seq_and_sizes(
     let (at_blob_content, size) = at_start_root.next().await?;
     // check the size to avoid parsing a maliciously large hash seq
     if size > max_size {
-        return Err(BadRequestSnafu.into_error(anyhow::anyhow!("size too large").into()));
+        return Err(BadRequestSnafu.into_error(anyhow::anyhow!("size too large")));
     }
     let (mut curr, hash_seq) = at_blob_content.concatenate_into_vec().await?;
-    let hash_seq = HashSeq::try_from(Bytes::from(hash_seq))
-        .map_err(|e| BadRequestSnafu.into_error(e.into()))?;
+    let hash_seq =
+        HashSeq::try_from(Bytes::from(hash_seq)).map_err(|e| BadRequestSnafu.into_error(e))?;
     let mut sizes = Vec::with_capacity(hash_seq.len());
     let closing = loop {
         match curr.next() {
diff --git a/src/protocol.rs b/src/protocol.rs
index ce10865a5..db5faf060 100644
--- a/src/protocol.rs
+++ b/src/protocol.rs
@@ -382,7 +382,6 @@ use bao_tree::{io::round_up_to_chunks, ChunkNum};
 use builder::GetRequestBuilder;
 use derive_more::From;
 use iroh::endpoint::VarInt;
-use irpc::util::AsyncReadVarintExt;
 use postcard::experimental::max_size::MaxSize;
 use range_collections::{range_set::RangeSetEntry, RangeSet2};
 use serde::{Deserialize, Serialize};
@@ -390,9 +389,8 @@ mod range_spec;
 pub use bao_tree::ChunkRanges;
 pub use range_spec::{ChunkRangesSeq, NonEmptyRequestRangeSpecIter, RangeSpec};
 use snafu::{GenerateImplicitData, Snafu};
-use tokio::io::AsyncReadExt;
 
-use crate::{api::blobs::Bitfield, provider::RecvStreamExt, BlobFormat, Hash, HashAndFormat};
+use crate::{api::blobs::Bitfield, util::RecvStreamExt, BlobFormat, Hash, HashAndFormat};
 
 /// Maximum message size is limited to 100MiB for now.
 pub const MAX_MESSAGE_SIZE: usize = 1024 * 1024;
@@ -448,7 +446,9 @@ pub enum RequestType {
 }
 
 impl Request {
-    pub async fn read_async(reader: &mut iroh::endpoint::RecvStream) -> io::Result<(Self, usize)> {
+    pub async fn read_async<R: RecvStream>(
+        reader: &mut R,
+    ) -> io::Result<(Self, usize)> {
         let request_type = reader.read_u8().await?;
         let request_type: RequestType = postcard::from_bytes(std::slice::from_ref(&request_type))
             .map_err(|_| {
diff --git a/src/provider.rs b/src/provider.rs
index ba415df41..904a272fe 100644
--- a/src/provider.rs
+++ b/src/provider.rs
@@ -5,32 +5,44 @@
 //! handler with an [`iroh::Endpoint`](iroh::protocol::Router).
 use std::{
     fmt::Debug,
+    future::Future,
     io,
     time::{Duration, Instant},
 };
 
-use anyhow::{Context, Result};
+use anyhow::Result;
 use bao_tree::ChunkRanges;
-use iroh::endpoint::{self, RecvStream, SendStream};
+use iroh::endpoint::{self, VarInt};
+use iroh_io::{AsyncStreamReader, AsyncStreamWriter};
 use n0_future::StreamExt;
-use quinn::{ClosedStream, ConnectionError, ReadToEndError};
-use serde::{de::DeserializeOwned, Deserialize, Serialize};
+use quinn::ConnectionError;
+use serde::{Deserialize, Serialize};
+use snafu::Snafu;
 use tokio::select;
 use tracing::{debug, debug_span, Instrument};
 
 use crate::{
     api::{
         blobs::{Bitfield, WriteProgress},
-        ExportBaoResult, Store,
+        ExportBaoError, ExportBaoResult, RequestError, Store,
     },
     hashseq::HashSeq,
-    protocol::{GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, Request},
-    provider::events::{ClientConnected, ClientResult, ConnectionClosed, RequestTracker},
+    protocol::{
+        GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, Request, ERR_INTERNAL,
+    },
+    provider::events::{
+        ClientConnected, ClientResult, ConnectionClosed, HasErrorCode, ProgressError,
+        RequestTracker,
+    },
+    util::{RecvStream, RecvStreamExt, SendStream, SendStreamExt},
     Hash,
 };
 
 pub mod events;
 use events::EventSender;
 
+type DefaultReader = iroh::endpoint::RecvStream;
+type DefaultWriter = iroh::endpoint::SendStream;
+
 /// Statistics about a successful or failed transfer.
 #[derive(Debug, Serialize, Deserialize)]
 pub struct TransferStats {
@@ -51,12 +63,11 @@ pub struct TransferStats {
 
 /// A pair of [`SendStream`] and [`RecvStream`] with additional context data.
 #[derive(Debug)]
-pub struct StreamPair {
+pub struct StreamPair<R: RecvStream = DefaultReader, W: SendStream = DefaultWriter> {
     t0: Instant,
     connection_id: u64,
-    request_id: u64,
-    reader: RecvStream,
-    writer: SendStream,
+    reader: R,
+    writer: W,
     other_bytes_read: u64,
     events: EventSender,
 }
@@ -64,18 +75,27 @@ impl StreamPair {
     pub async fn accept(
         conn: &endpoint::Connection,
-        events: &EventSender,
+        events: EventSender,
     ) -> Result<Self> {
         let (writer, reader) = conn.accept_bi().await?;
-        Ok(Self {
+        Ok(Self::new(conn.stable_id() as u64, reader, writer, events))
+    }
+}
+
+impl<R: RecvStream, W: SendStream> StreamPair<R, W> {
+    pub fn stream_id(&self) -> u64 {
+        self.reader.id()
+    }
+
+    pub fn new(connection_id: u64, reader: R, writer: W, events: EventSender) -> Self {
+        Self {
             t0: Instant::now(),
-            connection_id: conn.stable_id() as u64,
-            request_id: reader.id().into(),
+            connection_id,
             reader,
             writer,
             other_bytes_read: 0,
-            events: events.clone(),
-        })
+            events,
+        }
     }
 
     /// Read the request.
@@ -93,18 +113,12 @@ impl StreamPair {
     }
 
     /// We are done with reading. Return a ProgressWriter that contains the read stats and connection id
-    async fn into_writer(
+    pub async fn into_writer(
         mut self,
         tracker: RequestTracker,
-    ) -> Result<ProgressWriter, ReadToEndError> {
-        let res = self.reader.read_to_end(0).await;
-        if let Err(e) = res {
-            tracker
-                .transfer_aborted(|| Box::new(self.stats()))
-                .await
-                .ok();
-            return Err(e);
-        };
+    ) -> Result<ProgressWriter<W>, io::Error> {
+        self.reader.expect_eof().await?;
+        drop(self.reader);
         Ok(ProgressWriter::new(
             self.writer,
             WriterContext {
@@ -117,18 +131,12 @@ impl StreamPair {
         ))
     }
 
-    async fn into_reader(
+    pub async fn into_reader(
         mut self,
         tracker: RequestTracker,
-    ) -> Result<ProgressReader, ClosedStream> {
-        let res = self.writer.finish();
-        if let Err(e) = res {
-            tracker
-                .transfer_aborted(|| Box::new(self.stats()))
-                .await
-                .ok();
-            return Err(e);
-        };
+    ) -> Result<ProgressReader<R>, io::Error> {
+        self.writer.sync().await?;
+        drop(self.writer);
         Ok(ProgressReader {
             inner: self.reader,
             context: ReaderContext {
@@ -140,74 +148,42 @@ impl StreamPair {
     }
 
     pub async fn get_request(
-        mut self,
+        &self,
         f: impl FnOnce() -> GetRequest,
-    ) -> anyhow::Result<ProgressWriter> {
-        let res = self
-            .events
-            .request(f, self.connection_id, self.request_id)
-            .await;
-        match res {
-            Err(e) => {
-                self.writer.reset(e.code()).ok();
-                Err(e.into())
-            }
-            Ok(tracker) => Ok(self.into_writer(tracker).await?),
-        }
+    ) -> Result<RequestTracker, ProgressError> {
+        self.events
+            .request(f, self.connection_id, self.reader.id())
+            .await
     }
 
     pub async fn get_many_request(
-        mut self,
+        &self,
         f: impl FnOnce() -> GetManyRequest,
-    ) -> anyhow::Result<ProgressWriter> {
-        let res = self
-            .events
-            .request(f, self.connection_id, self.request_id)
-            .await;
-        match res {
-            Err(e) => {
-                self.writer.reset(e.code()).ok();
-                Err(e.into())
-            }
-            Ok(tracker) => Ok(self.into_writer(tracker).await?),
-        }
+    ) -> Result<RequestTracker, ProgressError> {
+        self.events
+            .request(f, self.connection_id, self.reader.id())
+            .await
     }
 
     pub async fn push_request(
-        mut self,
+        &self,
        f: impl FnOnce() -> PushRequest,
-    ) -> anyhow::Result<ProgressReader> {
-        let res = self
-            .events
-            .request(f, self.connection_id, self.request_id)
-            .await;
-        match res {
-            Err(e) => {
-                self.writer.reset(e.code()).ok();
-                Err(e.into())
-            }
-            Ok(tracker) => Ok(self.into_reader(tracker).await?),
-        }
+    ) -> Result<RequestTracker, ProgressError> {
+        self.events
+            .request(f, self.connection_id, self.reader.id())
+            .await
    }
 
     pub async fn observe_request(
-        mut self,
         f: impl FnOnce() -> ObserveRequest,
-    ) -> anyhow::Result<ProgressWriter> {
-        let res = self
-            .events
-            .request(f, self.connection_id, self.request_id)
-            .await;
-        match res {
-            Err(e) => {
-                self.writer.reset(e.code()).ok();
-                Err(e.into())
-            }
-            Ok(tracker) => Ok(self.into_writer(tracker).await?),
-        }
+    ) -> Result<RequestTracker, ProgressError> {
+        self.events
+            .request(f, self.connection_id, self.reader.id())
+            .await
     }
 
-    fn stats(&self) -> TransferStats {
+    pub fn stats(&self) -> TransferStats {
         TransferStats {
             payload_bytes_sent: 0,
             other_bytes_sent: 0,
@@ -282,14 +258,14 @@ impl WriteProgress for WriterContext {
 
 /// Wrapper for a [`quinn::SendStream`] with additional per request information.
 #[derive(Debug)]
-pub struct ProgressWriter {
+pub struct ProgressWriter<W = DefaultWriter> {
     /// The quinn::SendStream to write to
-    pub inner: SendStream,
+    pub inner: W,
     pub(crate) context: WriterContext,
 }
 
-impl ProgressWriter {
-    fn new(inner: SendStream, context: WriterContext) -> Self {
+impl<W: SendStream> ProgressWriter<W> {
+    fn new(inner: W, context: WriterContext) -> Self {
         Self { inner, context }
     }
 
@@ -330,10 +306,10 @@ pub async fn handle_connection(
         debug!("closing connection: {cause}");
         return;
     }
-    while let Ok(context) = StreamPair::accept(&connection, &progress).await {
-        let span = debug_span!("stream", stream_id = %context.request_id);
+    while let Ok(pair) = StreamPair::accept(&connection, progress.clone()).await {
+        let span = debug_span!("stream", stream_id = %pair.stream_id());
         let store = store.clone();
-        tokio::spawn(handle_stream(store, context).instrument(span));
+        tokio::spawn(handle_stream(pair, store).instrument(span));
     }
     progress
         .connection_closed(|| ConnectionClosed { connection_id })
@@ -344,58 +320,106 @@ pub async fn handle_connection(
         .await
 }
 
-async fn handle_stream(store: Store, mut context: StreamPair) -> anyhow::Result<()> {
-    // 1. Decode the request.
-    debug!("reading request");
-    let request = context.read_request().await?;
+/// Describes how to handle errors for a stream.
+pub trait ErrorHandler {
+    type W: AsyncStreamWriter;
+    type R: AsyncStreamReader;
+    fn stop(reader: &mut Self::R, code: VarInt) -> impl Future<Output = ()>;
+    fn reset(writer: &mut Self::W, code: VarInt) -> impl Future<Output = ()>;
+}
 
-    match request {
-        Request::Get(request) => {
-            let mut writer = context.get_request(|| request.clone()).await?;
-            let res = handle_get(store, request, &mut writer).await;
-            if res.is_ok() {
-                writer.transfer_completed().await;
-            } else {
-                writer.transfer_aborted().await;
-            }
-        }
-        Request::GetMany(request) => {
-            let mut writer = context.get_many_request(|| request.clone()).await?;
-            if handle_get_many(store, request, &mut writer).await.is_ok() {
-                writer.transfer_completed().await;
-            } else {
-                writer.transfer_aborted().await;
-            }
-        }
-        Request::Observe(request) => {
-            let mut writer = context.observe_request(|| request.clone()).await?;
-            if handle_observe(store, request, &mut writer).await.is_ok() {
-                writer.transfer_completed().await;
-            } else {
-                writer.transfer_aborted().await;
-            }
-        }
-        Request::Push(request) => {
-            let mut reader = context.push_request(|| request.clone()).await?;
-            if handle_push(store, request, &mut reader).await.is_ok() {
-                reader.transfer_completed().await;
-            } else {
-                reader.transfer_aborted().await;
-            }
-        }
+async fn handle_read_request_result<R: RecvStream, W: SendStream, T, E: HasErrorCode>(
+    pair: &mut StreamPair<R, W>,
+    r: Result<T, E>,
+) -> Result<T, E> {
+    match r {
+        Ok(x) => Ok(x),
+        Err(e) => {
+            pair.writer.reset(e.code()).ok();
+            Err(e)
+        }
+    }
+}
+async fn handle_write_result<W: SendStream, T, E: HasErrorCode>(
+    writer: &mut ProgressWriter<W>,
+    r: Result<T, E>,
+) -> Result<T, E> {
+    match r {
+        Ok(x) => {
+            writer.transfer_completed().await;
+            Ok(x)
+        }
+        Err(e) => {
+            writer.inner.reset(e.code()).ok();
+            writer.transfer_aborted().await;
+            Err(e)
+        }
+    }
+}
+async fn handle_read_result<R: RecvStream, T, E: HasErrorCode>(
+    reader: &mut ProgressReader<R>,
+    r: Result<T, E>,
+) -> Result<T, E> {
+    match r {
+        Ok(x) => {
+            reader.transfer_completed().await;
+            Ok(x)
+        }
+        Err(e) => {
+            reader.inner.stop(e.code()).ok();
+            reader.transfer_aborted().await;
+            Err(e)
+        }
+    }
+}
+
+pub async fn handle_stream<R: RecvStream, W: SendStream>(
+    mut pair: StreamPair<R, W>,
+    store: Store,
+) -> anyhow::Result<()> {
+    let request = pair.read_request().await?;
+    match request {
+        Request::Get(request) => handle_get(pair, store, request).await?,
+        Request::GetMany(request) => handle_get_many(pair, store, request).await?,
+        Request::Observe(request) => handle_observe(pair, store, request).await?,
+        Request::Push(request) => handle_push(pair, store, request).await?,
         _ => {}
     }
     Ok(())
 }
 
+#[derive(Debug, Snafu)]
+#[snafu(module)]
+pub enum HandleGetError {
+    #[snafu(transparent)]
+    ExportBao {
+        source: ExportBaoError,
+    },
+    InvalidHashSeq,
+    InvalidOffset,
+}
+
+impl HasErrorCode for HandleGetError {
+    fn code(&self) -> VarInt {
+        match self {
+            HandleGetError::ExportBao {
+                source: ExportBaoError::ClientError { source, .. },
+            } => source.code(),
+            HandleGetError::InvalidHashSeq => ERR_INTERNAL,
+            HandleGetError::InvalidOffset => ERR_INTERNAL,
+            _ => ERR_INTERNAL,
+        }
+    }
+}
+
 /// Handle a single get request.
 ///
 /// Requires a database, the request, and a writer.
-pub async fn handle_get(
+async fn handle_get_impl<W: SendStream>(
     store: Store,
     request: GetRequest,
-    writer: &mut ProgressWriter,
-) -> anyhow::Result<()> {
+    writer: &mut ProgressWriter<W>,
+) -> Result<(), HandleGetError> {
     let hash = request.hash;
     debug!(%hash, "get received request");
     let mut hash_seq = None;
@@ -412,30 +436,66 @@ pub async fn handle_get(
                 Some(b) => b,
                 None => {
                     let bytes = store.get_bytes(hash).await?;
-                    let hs = HashSeq::try_from(bytes)?;
+                    let hs =
+                        HashSeq::try_from(bytes).map_err(|_| HandleGetError::InvalidHashSeq)?;
                     hash_seq = Some(hs);
                     hash_seq.as_ref().unwrap()
                 }
             };
-            let o = usize::try_from(offset - 1).context("offset too large")?;
+            let o = usize::try_from(offset - 1).map_err(|_| HandleGetError::InvalidOffset)?;
             let Some(hash) = hash_seq.get(o) else {
                 break;
             };
             send_blob(&store, offset, hash, ranges.clone(), writer).await?;
         }
     }
+    writer
+        .inner
+        .sync()
+        .await
+        .map_err(|e| HandleGetError::ExportBao { source: e.into() })?;
     Ok(())
 }
 
+pub async fn handle_get<R: RecvStream, W: SendStream>(
+    mut pair: StreamPair<R, W>,
+    store: Store,
+    request: GetRequest,
+) -> anyhow::Result<()> {
+    let res = pair.get_request(|| request.clone()).await;
+    let tracker = handle_read_request_result(&mut pair, res).await?;
+    let mut writer = pair.into_writer(tracker).await?;
+    let res = handle_get_impl(store, request, &mut writer).await;
+    handle_write_result(&mut writer, res).await?;
+    Ok(())
+}
+
+#[derive(Debug, Snafu)]
+pub enum HandleGetManyError {
+    #[snafu(transparent)]
+    ExportBao { source: ExportBaoError },
+}
+
+impl HasErrorCode for HandleGetManyError {
+    fn code(&self) -> VarInt {
+        match self {
+            Self::ExportBao {
+                source: ExportBaoError::ClientError { source, .. },
+            } => source.code(),
+            _ => ERR_INTERNAL,
+        }
+    }
+}
+
 /// Handle a single get request.
 ///
 /// Requires a database, the request, and a writer.
-pub async fn handle_get_many(
+async fn handle_get_many_impl<W: SendStream>(
     store: Store,
     request: GetManyRequest,
-    writer: &mut ProgressWriter,
-) -> Result<()> {
+    writer: &mut ProgressWriter<W>,
+) -> Result<(), HandleGetManyError> {
     debug!("get_many received request");
     let request_ranges = request.ranges.iter_infinite();
     for (child, (hash, ranges)) in request.hashes.iter().zip(request_ranges).enumerate() {
@@ -446,14 +506,53 @@ pub async fn handle_get_many(
     Ok(())
 }
 
+pub async fn handle_get_many<R: RecvStream, W: SendStream>(
+    mut pair: StreamPair<R, W>,
+    store: Store,
+    request: GetManyRequest,
+) -> anyhow::Result<()> {
+    let res = pair.get_many_request(|| request.clone()).await;
+    let tracker = handle_read_request_result(&mut pair, res).await?;
+    let mut writer = pair.into_writer(tracker).await?;
+    let res = handle_get_many_impl(store, request, &mut writer).await;
+    handle_write_result(&mut writer, res).await?;
+    Ok(())
+}
+
+#[derive(Debug, Snafu)]
+pub enum HandlePushError {
+    #[snafu(transparent)]
+    ExportBao {
+        source: ExportBaoError,
+    },
+
+    InvalidHashSeq,
+
+    #[snafu(transparent)]
+    Request {
+        source: RequestError,
+    },
+}
+
+impl HasErrorCode for HandlePushError {
+    fn code(&self) -> VarInt {
+        match self {
+            Self::ExportBao {
+                source: ExportBaoError::ClientError { source, .. },
+            } => source.code(),
+            _ => ERR_INTERNAL,
+        }
+    }
+}
+
 /// Handle a single push request.
 ///
 /// Requires a database, the request, and a reader.
-pub async fn handle_push(
+async fn handle_push_impl<R: RecvStream>(
     store: Store,
     request: PushRequest,
-    reader: &mut ProgressReader,
-) -> Result<()> {
+    reader: &mut ProgressReader<R>,
+) -> Result<(), HandlePushError> {
     let hash = request.hash;
     debug!(%hash, "push received request");
     let mut request_ranges = request.ranges.iter_infinite();
@@ -461,7 +560,7 @@ pub async fn handle_push(
     if !root_ranges.is_empty() {
         // todo: send progress from import_bao_quinn or rename to import_bao_quinn_with_progress
         store
-            .import_bao_quinn(hash, root_ranges.clone(), &mut reader.inner)
+            .import_bao_reader(hash, root_ranges.clone(), &mut reader.inner)
             .await?;
     }
     if request.ranges.is_blob() {
@@ -470,52 +569,85 @@
     }
     // todo: we assume here that the hash sequence is complete. For some requests this might not be the case. We would need `LazyHashSeq` for that, but it is buggy as of now!
     let hash_seq = store.get_bytes(hash).await?;
-    let hash_seq = HashSeq::try_from(hash_seq)?;
+    let hash_seq = HashSeq::try_from(hash_seq).map_err(|_| HandlePushError::InvalidHashSeq)?;
     for (child_hash, child_ranges) in hash_seq.into_iter().zip(request_ranges) {
         if child_ranges.is_empty() {
             continue;
         }
         store
-            .import_bao_quinn(child_hash, child_ranges.clone(), &mut reader.inner)
+            .import_bao_reader(child_hash, child_ranges.clone(), &mut reader.inner)
             .await?;
     }
     Ok(())
 }
 
+pub async fn handle_push<R: RecvStream, W: SendStream>(
+    mut pair: StreamPair<R, W>,
+    store: Store,
+    request: PushRequest,
+) -> anyhow::Result<()> {
+    let res = pair.push_request(|| request.clone()).await;
+    let tracker = handle_read_request_result(&mut pair, res).await?;
+    let mut reader = pair.into_reader(tracker).await?;
+    let res = handle_push_impl(store, request, &mut reader).await;
+    handle_read_result(&mut reader, res).await?;
+    Ok(())
+}
+
 /// Send a blob to the client.
-pub(crate) async fn send_blob(
+pub(crate) async fn send_blob<W: SendStream>(
     store: &Store,
     index: u64,
     hash: Hash,
     ranges: ChunkRanges,
-    writer: &mut ProgressWriter,
+    writer: &mut ProgressWriter<W>,
 ) -> ExportBaoResult<()> {
     store
         .export_bao(hash, ranges)
-        .write_quinn_with_progress(&mut writer.inner, &mut writer.context, &hash, index)
+        .write_with_progress(&mut writer.inner, &mut writer.context, &hash, index)
         .await
 }
 
+#[derive(Debug, Snafu)]
+pub enum HandleObserveError {
+    ObserveStreamClosed,
+
+    #[snafu(transparent)]
+    RemoteClosed {
+        source: io::Error,
+    },
+}
+
+impl HasErrorCode for HandleObserveError {
+    fn code(&self) -> VarInt {
+        ERR_INTERNAL
+    }
+}
+
 /// Handle a single push request.
 ///
 /// Requires a database, the request, and a reader.
-pub async fn handle_observe(
+async fn handle_observe_impl<W: SendStream>(
     store: Store,
     request: ObserveRequest,
-    writer: &mut ProgressWriter,
-) -> Result<()> {
-    let mut stream = store.observe(request.hash).stream().await?;
+    writer: &mut ProgressWriter<W>,
+) -> std::result::Result<(), HandleObserveError> {
+    let mut stream = store
+        .observe(request.hash)
+        .stream()
+        .await
+        .map_err(|_| HandleObserveError::ObserveStreamClosed)?;
     let mut old = stream
         .next()
         .await
-        .ok_or(anyhow::anyhow!("observe stream closed before first value"))?;
+        .ok_or(HandleObserveError::ObserveStreamClosed)?;
     // send the initial bitfield
     send_observe_item(writer, &old).await?;
     // send updates until the remote loses interest
     loop {
         select! {
             new = stream.next() => {
-                let new = new.context("observe stream closed")?;
+                let new = new.ok_or(HandleObserveError::ObserveStreamClosed)?;
                 let diff = old.diff(&new);
                 if diff.is_empty() {
                     continue;
@@ -532,20 +664,35 @@ pub async fn handle_observe(
     Ok(())
 }
 
-async fn send_observe_item(writer: &mut ProgressWriter, item: &Bitfield) -> Result<()> {
-    use irpc::util::AsyncWriteVarintExt;
+async fn send_observe_item<W: SendStream>(
+    writer: &mut ProgressWriter<W>,
+    item: &Bitfield,
+) -> io::Result<()> {
     let item = ObserveItem::from(item);
     let len = writer.inner.write_length_prefixed(item).await?;
     writer.context.log_other_write(len);
     Ok(())
 }
 
-pub struct ProgressReader {
-    inner: RecvStream,
+pub async fn handle_observe<R: RecvStream, W: SendStream>(
+    mut pair: StreamPair<R, W>,
+    store: Store,
+    request: ObserveRequest,
+) -> anyhow::Result<()> {
+    let res = pair.observe_request(|| request.clone()).await;
+    let tracker = handle_read_request_result(&mut pair, res).await?;
+    let mut writer = pair.into_writer(tracker).await?;
+    let res = handle_observe_impl(store, request, &mut writer).await;
+    handle_write_result(&mut writer, res).await?;
+    Ok(())
+}
+
+pub struct ProgressReader<R = DefaultReader> {
+    inner: R,
     context: ReaderContext,
 }
 
-impl ProgressReader {
+impl<R: RecvStream> ProgressReader<R> {
     async fn transfer_aborted(&self) {
         self.context
             .tracker
@@ -562,25 +709,3 @@ impl ProgressReader {
             .ok();
     }
 }
-
-pub(crate) trait RecvStreamExt {
-    async fn read_to_end_as<T: DeserializeOwned>(
-        &mut self,
-        max_size: usize,
-    ) -> io::Result<(T, usize)>;
-}
-
-impl RecvStreamExt for RecvStream {
-    async fn read_to_end_as<T: DeserializeOwned>(
-        &mut self,
-        max_size: usize,
-    ) -> io::Result<(T, usize)> {
-        let data = self
-            .read_to_end(max_size)
-            .await
-            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
-        let value = postcard::from_bytes(&data)
-            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
-        Ok((value, data.len()))
-    }
-}
diff --git a/src/provider/events.rs b/src/provider/events.rs
index 40ec56f89..85a4dbcb2 100644
--- a/src/provider/events.rs
+++ b/src/provider/events.rs
@@ -105,15 +105,21 @@ impl From<ProgressError> for io::Error {
     }
 }
 
-impl ProgressError {
-    pub fn code(&self) -> quinn::VarInt {
+pub trait HasErrorCode {
+    fn code(&self) -> quinn::VarInt;
+}
+
+impl HasErrorCode for ProgressError {
+    fn code(&self) -> quinn::VarInt {
         match self {
             ProgressError::Limit => ERR_LIMIT,
             ProgressError::Permission => ERR_PERMISSION,
             ProgressError::Internal { .. } => ERR_INTERNAL,
         }
     }
+}
 
+impl ProgressError {
     pub fn reason(&self) -> &'static [u8] {
         match self {
             ProgressError::Limit => b"limit",
diff --git a/src/store/fs/util/entity_manager.rs b/src/store/fs/util/entity_manager.rs
index 91a737d76..b0b2898ea 100644
--- a/src/store/fs/util/entity_manager.rs
+++ b/src/store/fs/util/entity_manager.rs
@@ -1186,10 +1186,6 @@ mod tests {
             .spawn(id, move |arg| async move {
                 match arg {
                     SpawnArg::Active(state) => {
-                        println!(
-                            "Adding value {} to entity actor with id {:?}",
-                            value, state.id
-                        );
                         state
                             .with_value(|v| *v = v.wrapping_add(value))
                             .await
diff --git a/src/tests.rs b/src/tests.rs
index 09b2e5b33..cbd429eb7 100644
--- a/src/tests.rs
+++ b/src/tests.rs
@@ -556,6 +556,7 @@ async fn two_nodes_hash_seq(
 }
 
 #[tokio::test]
+
 async fn two_nodes_hash_seq_fs() -> TestResult<()> {
     tracing_subscriber::fmt::try_init().ok();
     let (_testdir, (r1, store1, _), (r2, store2, _)) = two_node_test_setup_fs().await?;
@@ -578,9 +579,7 @@ async fn two_nodes_hash_seq_progress() -> TestResult<()> {
     let root = add_test_hash_seq(&store1, sizes).await?;
     let conn = r2.endpoint().connect(addr1, crate::ALPN).await?;
     let mut stream = store2.remote().fetch(conn, root).stream();
-    while let Some(item) = stream.next().await {
-        println!("{item:?}");
-    }
+    while stream.next().await.is_some() {}
     check_presence(&store2, &sizes).await?;
     Ok(())
 }
@@ -648,9 +647,7 @@ async fn node_serve_blobs() -> TestResult<()> {
         let expected = test_data(size);
         let hash = Hash::new(&expected);
         let mut stream = get::request::get_blob(conn.clone(), hash);
-        while let Some(item) = stream.next().await {
-            println!("{item:?}");
-        }
+        while stream.next().await.is_some() {}
         let actual = get::request::get_blob(conn.clone(), hash).await?;
         assert_eq!(actual.len(), expected.len(), "size: {size}");
     }
diff --git a/src/util.rs b/src/util.rs
index bc9c25694..c0acfcaad 100644
--- a/src/util.rs
+++ b/src/util.rs
@@ -1,7 +1,13 @@
 //! Utilities
 pub(crate) mod channel;
 pub mod connection_pool;
+mod stream;
 pub(crate) mod temp_tag;
+pub use stream::{
+    AsyncReadRecvStream, AsyncReadRecvStreamExtra, AsyncWriteSendStream, AsyncWriteSendStreamExtra,
+    RecvStream, RecvStreamAsyncStreamReader, SendStream,
+};
+pub(crate) use stream::{RecvStreamExt, SendStreamExt};
 
 pub(crate) mod serde {
     // Module that handles io::Error serialization/deserialization
diff --git a/src/util/stream.rs b/src/util/stream.rs
new file mode 100644
index 000000000..2816338b1
--- /dev/null
+++ b/src/util/stream.rs
@@ -0,0 +1,469 @@
+use std::{
+    future::Future,
+    io,
+    ops::{Deref, DerefMut},
+};
+
+use bytes::Bytes;
+use iroh::endpoint::{ReadExactError, VarInt};
+use iroh_io::AsyncStreamReader;
+use serde::{de::DeserializeOwned, Serialize};
+use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
+
+/// An abstract `iroh::endpoint::SendStream`.
+pub trait SendStream: Send {
+    /// Send bytes to the stream. This takes a `Bytes` because iroh can directly use them.
+    ///
+    /// This method is not cancellation safe: even if the returned future never resolves, some bytes may already have been written by earlier polls.
+    fn send_bytes(&mut self, bytes: Bytes) -> impl Future<Output = io::Result<()>> + Send;
+    /// Send a fixed-size buffer.
+    fn send(&mut self, buf: &[u8]) -> impl Future<Output = io::Result<()>> + Send;
+    /// Sync the stream. Not needed for iroh, but needed for intermediate buffered streams such as compression.
+    fn sync(&mut self) -> impl Future<Output = io::Result<()>> + Send;
+    /// Reset the stream with the given error code.
+    fn reset(&mut self, code: VarInt) -> io::Result<()>;
+    /// Wait for the stream to be stopped, returning the error code if there was one.
+    fn stopped(&mut self) -> impl Future<Output = io::Result<Option<VarInt>>> + Send;
+    /// Get the stream id.
+    fn id(&self) -> u64;
+}
+
+/// An abstract `iroh::endpoint::RecvStream`.
+pub trait RecvStream: Send {
+    /// Receive up to `len` bytes from the stream, directly into a `Bytes`.
+    fn recv_bytes(&mut self, len: usize) -> impl Future<Output = io::Result<Bytes>> + Send;
+    /// Receive exactly `len` bytes from the stream, directly into a `Bytes`.
+    ///
+    /// This will return an error if the stream ends before `len` bytes are read.
+    ///
+    /// Note that this is different from `recv_bytes`, which will return fewer bytes if the stream ends.
+    fn recv_bytes_exact(&mut self, len: usize) -> impl Future<Output = io::Result<Bytes>> + Send;
+    /// Receive exactly `target.len()` bytes from the stream.
+    fn recv_exact(&mut self, target: &mut [u8]) -> impl Future<Output = io::Result<()>> + Send;
+    /// Stop the stream with the given error code.
+    fn stop(&mut self, code: VarInt) -> io::Result<()>;
+    /// Get the stream id.
+    fn id(&self) -> u64;
+}
+
+impl SendStream for iroh::endpoint::SendStream {
+    async fn send_bytes(&mut self, bytes: Bytes) -> io::Result<()> {
+        Ok(self.write_chunk(bytes).await?)
+    }
+
+    async fn send(&mut self, buf: &[u8]) -> io::Result<()> {
+        Ok(self.write_all(buf).await?)
+    }
+
+    async fn sync(&mut self) -> io::Result<()> {
+        Ok(())
+    }
+
+    fn reset(&mut self, code: VarInt) -> io::Result<()> {
+        Ok(self.reset(code)?)
+    }
+
+    async fn stopped(&mut self) -> io::Result<Option<VarInt>> {
+        Ok(self.stopped().await?)
+    }
+
+    fn id(&self) -> u64 {
+        self.id().index()
+    }
+}
+
+impl RecvStream for iroh::endpoint::RecvStream {
+    async fn recv_bytes(&mut self, len: usize) -> io::Result<Bytes> {
+        let mut buf = vec![0; len];
+        match self.read_exact(&mut buf).await {
+            Err(ReadExactError::FinishedEarly(n)) => {
+                buf.truncate(n);
+            }
+            Err(ReadExactError::ReadError(e)) => {
+                return Err(e.into());
+            }
+            Ok(()) => {}
+        };
+        Ok(buf.into())
+    }
+
+    async fn recv_bytes_exact(&mut self, len: usize) -> io::Result<Bytes> {
+        let mut buf = vec![0; len];
+        self.read_exact(&mut buf).await.map_err(|e| match e {
+            ReadExactError::FinishedEarly(0) => io::Error::new(io::ErrorKind::UnexpectedEof, ""),
+            ReadExactError::FinishedEarly(_) => io::Error::new(io::ErrorKind::InvalidData, ""),
+            ReadExactError::ReadError(e) => e.into(),
+        })?;
+        Ok(buf.into())
+    }
+
+    async fn recv_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
+        self.read_exact(buf).await.map_err(|e| match e {
+            ReadExactError::FinishedEarly(0) => io::Error::new(io::ErrorKind::UnexpectedEof, ""),
+            ReadExactError::FinishedEarly(_) => io::Error::new(io::ErrorKind::InvalidData, ""),
+            ReadExactError::ReadError(e) => e.into(),
+        })
+    }
+
+    fn stop(&mut self, code: VarInt) -> io::Result<()> {
+        Ok(self.stop(code)?)
+    }
+
+    fn id(&self) -> u64 {
+        self.id().index()
+    }
+}
+
+impl<R: RecvStream> RecvStream for &mut R {
+    async fn recv_bytes(&mut self, len: usize) -> io::Result<Bytes> {
+        self.deref_mut().recv_bytes(len).await
+    }
+
+    async fn recv_bytes_exact(&mut self, len: usize) -> io::Result<Bytes> {
+        self.deref_mut().recv_bytes_exact(len).await
+    }
+
+    async fn recv_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
+        self.deref_mut().recv_exact(buf).await
+    }
+
+    fn stop(&mut self, code: VarInt) -> io::Result<()> {
+        self.deref_mut().stop(code)
+    }
+
+    fn id(&self) -> u64 {
+        self.deref().id()
+    }
+}
+
+impl<W: SendStream> SendStream for &mut W {
+    async fn send_bytes(&mut self, bytes: Bytes) -> io::Result<()> {
+        self.deref_mut().send_bytes(bytes).await
+    }
+
+    async fn send(&mut self, buf: &[u8]) -> io::Result<()> {
+        self.deref_mut().send(buf).await
+    }
+
+    async fn sync(&mut self) -> io::Result<()> {
+        self.deref_mut().sync().await
+    }
+
+    fn reset(&mut self, code: VarInt) -> io::Result<()> {
+        self.deref_mut().reset(code)
+    }
+
+    async fn stopped(&mut self) -> io::Result<Option<VarInt>> {
+        self.deref_mut().stopped().await
+    }
+
+    fn id(&self) -> u64 {
+        self.deref().id()
+    }
+}
+
+#[derive(Debug)]
+pub struct AsyncReadRecvStream<R>(R);
+
+/// This is a helper trait to work with [`AsyncReadRecvStream`]. If you have an
+/// `AsyncRead + Unpin + Send`, you can implement these additional methods and wrap the result
+/// in an `AsyncReadRecvStream` to get a `RecvStream` that reads from the underlying `AsyncRead`.
+pub trait AsyncReadRecvStreamExtra: Send {
+    /// Get a mutable reference to the inner `AsyncRead`.
+    ///
+    /// Getting a reference is easier than implementing all methods on `AsyncWrite` with forwarders to the inner instance.
+    fn inner(&mut self) -> &mut (impl AsyncRead + Unpin + Send);
+    /// Stop the stream with the given error code.
+    fn stop(&mut self, code: VarInt) -> io::Result<()>;
+    /// A local unique identifier for the stream.
+    ///
+    /// This allows distinguishing between streams, but once the stream is closed, the id may be reused.
+    fn id(&self) -> u64;
+}
+
+impl<R: AsyncReadRecvStreamExtra> AsyncReadRecvStream<R> {
+    pub fn new(inner: R) -> Self {
+        Self(inner)
+    }
+}
+
+impl<R: AsyncReadRecvStreamExtra> RecvStream for AsyncReadRecvStream<R> {
+    async fn recv_bytes(&mut self, len: usize) -> io::Result<Bytes> {
+        let mut res = vec![0; len];
+        let mut n = 0;
+        loop {
+            let read = self.0.inner().read(&mut res[n..]).await?;
+            if read == 0 {
+                res.truncate(n);
+                break;
+            }
+            n += read;
+            if n == len {
+                break;
+            }
+        }
+        Ok(res.into())
+    }
+
+    async fn recv_bytes_exact(&mut self, len: usize) -> io::Result<Bytes> {
+        let mut res = vec![0; len];
+        self.0.inner().read_exact(&mut res).await?;
+        Ok(res.into())
+    }
+
+    async fn recv_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
+        self.0.inner().read_exact(buf).await?;
+        Ok(())
+    }
+
+    fn stop(&mut self, code: VarInt) -> io::Result<()> {
+        self.0.stop(code)
+    }
+
+    fn id(&self) -> u64 {
+        self.0.id()
+    }
+}
+
+impl RecvStream for Bytes {
+    async fn recv_bytes(&mut self, len: usize) -> io::Result<Bytes> {
+        let n = len.min(self.len());
+        let res = self.slice(..n);
+        *self = self.slice(n..);
+        Ok(res)
+    }
+
+    async fn recv_bytes_exact(&mut self, len: usize) -> io::Result<Bytes> {
+        if self.len() < len {
+            return Err(io::ErrorKind::UnexpectedEof.into());
+        }
+        let res = self.slice(..len);
+        *self = self.slice(len..);
+        Ok(res)
+    }
+
+    async fn recv_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
+        if self.len() < buf.len() {
+            return Err(io::ErrorKind::UnexpectedEof.into());
+        }
+        buf.copy_from_slice(&self[..buf.len()]);
+        *self = self.slice(buf.len()..);
+        Ok(())
+    }
+
+    fn stop(&mut self, _code: VarInt) -> io::Result<()> {
+        Ok(())
+    }
+
+    fn id(&self) -> u64 {
+        0
+    }
+}
+
+/// Utility to convert a [`tokio::io::AsyncWrite`] into a [`SendStream`].
+#[derive(Debug, Clone)]
+pub struct AsyncWriteSendStream<W: AsyncWriteSendStreamExtra>(W);
+
+/// This is a helper trait to work with [`AsyncWriteSendStream`].
+///
+/// If you have an `AsyncWrite + Unpin + Send`, you can implement these additional
+/// methods and wrap the result in an `AsyncWriteSendStream` to get a `SendStream`
+/// that writes to the underlying `AsyncWrite`.
+pub trait AsyncWriteSendStreamExtra: Send {
+    /// Get a mutable reference to the inner `AsyncWrite`.
+    ///
+    /// Getting a reference is easier than implementing all methods on `AsyncWrite` with forwarders to the inner instance.
+    fn inner(&mut self) -> &mut (impl AsyncWrite + Unpin + Send);
+    /// Reset the stream with the given error code.
+    fn reset(&mut self, code: VarInt) -> io::Result<()>;
+    /// Wait for the stream to be stopped, returning the error code if there was one.
+    fn stopped(&mut self) -> impl Future<Output = io::Result<Option<VarInt>>> + Send;
+    /// A local unique identifier for the stream.
+    ///
+    /// This allows distinguishing between streams, but once the stream is closed, the id may be reused.
+    fn id(&self) -> u64;
+}
+
+impl<W: AsyncWriteSendStreamExtra> AsyncWriteSendStream<W> {
+    pub fn new(inner: W) -> Self {
+        Self(inner)
+    }
+}
+
+impl<W: AsyncWriteSendStreamExtra> AsyncWriteSendStream<W> {
+    pub fn into_inner(self) -> W {
+        self.0
+    }
+}
+
+impl<W: AsyncWriteSendStreamExtra> SendStream for AsyncWriteSendStream<W> {
+    async fn send_bytes(&mut self, bytes: Bytes) -> io::Result<()> {
+        self.0.inner().write_all(&bytes).await
+    }
+
+    async fn send(&mut self, buf: &[u8]) -> io::Result<()> {
+        self.0.inner().write_all(buf).await
+    }
+
+    async fn sync(&mut self) -> io::Result<()> {
+        self.0.inner().flush().await
+    }
+
+    fn reset(&mut self, code: VarInt) -> io::Result<()> {
+        self.0.reset(code)?;
+        Ok(())
+    }
+
+    async fn stopped(&mut self) -> io::Result<Option<VarInt>> {
+        let res = self.0.stopped().await?;
+        Ok(res)
+    }
+
+    fn id(&self) -> u64 {
+        self.0.id()
+    }
+}
+
+#[derive(Debug)]
+pub struct RecvStreamAsyncStreamReader<R: RecvStream>(R);
+
+impl<R: RecvStream> RecvStreamAsyncStreamReader<R> {
+    pub fn new(inner: R) -> Self {
+        Self(inner)
+    }
+
+    pub fn into_inner(self) -> R {
+        self.0
+    }
+}
+
+impl<R: RecvStream> AsyncStreamReader for RecvStreamAsyncStreamReader<R> {
+    async fn read_bytes(&mut self, len: usize) -> io::Result<Bytes> {
+        self.0.recv_bytes_exact(len).await
+    }
+
+    async fn read<const L: usize>(&mut self) -> io::Result<[u8; L]> {
+        let mut buf = [0; L];
+        self.0.recv_exact(&mut buf).await?;
+        Ok(buf)
+    }
+}
+
+pub(crate) trait RecvStreamExt: RecvStream {
+    async fn expect_eof(&mut self) -> io::Result<()> {
+        match self.read_u8().await {
+            Ok(_) => Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                "unexpected data",
+            )),
+            Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => Ok(()),
+            Err(e) => Err(e),
+        }
+    }
+
+    async fn read_u8(&mut self) -> io::Result<u8> {
+        let mut buf = [0; 1];
+        self.recv_exact(&mut buf).await?;
+        Ok(buf[0])
+    }
+
+    async fn read_to_end_as<T: DeserializeOwned>(
+        &mut self,
+        max_size: usize,
+    ) -> io::Result<(T, usize)> {
+        let data = self.recv_bytes(max_size).await?;
+        self.expect_eof().await?;
+        let value = postcard::from_bytes(&data)
+            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
+        Ok((value, data.len()))
+    }
+
+    async fn read_length_prefixed<T: DeserializeOwned>(
+        &mut self,
+        max_size: usize,
+    ) -> io::Result<T> {
+        let Some(n) = self.read_varint_u64().await? else {
+            return Err(io::ErrorKind::UnexpectedEof.into());
+        };
+        if n > max_size as u64 {
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidData,
+                "length prefix too large",
+            ));
+        }
+        let n = n as usize;
+        let data = self.recv_bytes(n).await?;
+        let value = postcard::from_bytes(&data)
+            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
+        Ok(value)
+    }
+
+    /// Reads a u64 varint from the stream, using the Postcard/LEB128 format.
+    ///
+    /// In Postcard's varint format (LEB128):
+    /// - Each byte uses 7 bits for the value
+    /// - The MSB (most significant bit) of each byte indicates if there are more bytes (1) or not (0)
+    /// - Values are stored in little-endian order (least significant group first)
+    ///
+    /// Returns the decoded u64 value, or `None` if the stream ends before the first byte.
+    async fn read_varint_u64(&mut self) -> io::Result<Option<u64>> {
+        let mut result: u64 = 0;
+        let mut shift: u32 = 0;
+
+        loop {
+            // We can only shift up to 63 bits (for a u64)
+            if shift >= 64 {
+                return Err(io::Error::new(
+                    io::ErrorKind::InvalidData,
+                    "Varint is too large for u64",
+                ));
+            }
+
+            // Read a single byte; EOF before the first byte just means there is no varint
+            let byte = match self.read_u8().await {
+                Ok(byte) => byte,
+                Err(cause) if shift == 0 && cause.kind() == io::ErrorKind::UnexpectedEof => {
+                    return Ok(None);
+                }
+                Err(cause) => return Err(cause),
+            };
+
+            // Extract the 7 value bits (bits 0-6, excluding the MSB which is the continuation bit)
+            let value = (byte & 0x7F) as u64;
+
+            // Add the bits to our result at the current shift position
+            result |= value << shift;
+
+            // If the high bit is not set (0), this is the last byte
+            if byte & 0x80 == 0 {
+                break;
+            }
+
+            // Move to the next 7 bits
+            shift += 7;
+        }
+
+        Ok(Some(result))
+    }
+}
+
+impl<R: RecvStream> RecvStreamExt for R {}
+
+pub(crate) trait SendStreamExt: SendStream {
+    async fn write_length_prefixed<T: Serialize>(&mut self, value: T) -> io::Result<usize> {
+        let size = postcard::experimental::serialized_size(&value)
+            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
+        let mut buf = Vec::with_capacity(size + 9);
+        irpc::util::WriteVarintExt::write_length_prefixed(&mut buf, value)?;
+        let n = buf.len();
+        self.send_bytes(buf.into()).await?;
+        Ok(n)
+    }
+}
+
+impl<W: SendStream> SendStreamExt for W {}
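
Usage note (not part of the diff): with the reworked GetError in src/get/get_error.rs above, QUIC-level details now live inside io::Error sources, and iroh_error_code() is the single accessor for whatever code the remote signalled via reset, stop, or application close. A minimal sketch, assuming GetError stays re-exported as iroh_blobs::get::GetError; the report function is hypothetical:

use iroh_blobs::get::GetError;

fn report(err: &GetError) {
    match err.iroh_error_code() {
        // the remote signalled an application error code
        Some(code) => eprintln!("remote error code: {code:?}"),
        // everything else is a local or protocol-level failure
        None => eprintln!("local error: {err}"),
    }
}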
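
Because Request::read_async is now generic over the RecvStream trait, and Bytes itself implements RecvStream (src/util/stream.rs), a request can be decoded from an in-memory buffer, which is convenient in tests. A hedged sketch, assuming `encoded` holds the wire form of a request (one request-type byte followed by the postcard body, as read_async expects):

use bytes::Bytes;
use iroh_blobs::protocol::Request;

async fn decode_request(encoded: Bytes) -> std::io::Result<Request> {
    // Bytes stands in for a network stream here
    let mut stream = encoded;
    let (request, bytes_read) = Request::read_async(&mut stream).await?;
    println!("decoded request from {bytes_read} bytes");
    Ok(request)
}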
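
The two receive flavours on RecvStream differ only at end-of-stream: recv_bytes returns a short (possibly empty) buffer, while recv_bytes_exact fails. This mirrors the Bytes impl in the diff exactly, so it can be checked in memory:

use bytes::Bytes;
use iroh_blobs::util::RecvStream;

async fn eof_semantics() -> std::io::Result<()> {
    let mut buf = Bytes::from_static(b"abc");
    assert_eq!(&buf.recv_bytes(2).await?[..], b"ab"); // normal read
    assert_eq!(&buf.recv_bytes(5).await?[..], b"c"); // short read at end of stream
    let mut buf = Bytes::from_static(b"abc");
    assert!(buf.recv_bytes_exact(5).await.is_err()); // exact read fails instead
    Ok(())
}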
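
The AsyncReadRecvStream/AsyncWriteSendStream adapters are what let wrapped transports (compression, in-memory pipes, metrics shims) look like iroh streams. A hedged sketch for the receive side over a tokio duplex pipe; the DuplexRecv type and its stop/id choices are illustrative only, not part of the crate:

use std::io;

use iroh::endpoint::VarInt;
use iroh_blobs::util::{AsyncReadRecvStream, AsyncReadRecvStreamExtra};
use tokio::io::{AsyncRead, DuplexStream};

struct DuplexRecv(DuplexStream);

impl AsyncReadRecvStreamExtra for DuplexRecv {
    fn inner(&mut self) -> &mut (impl AsyncRead + Unpin + Send) {
        &mut self.0
    }

    fn stop(&mut self, _code: VarInt) -> io::Result<()> {
        // an in-memory pipe has no error codes to propagate
        Ok(())
    }

    fn id(&self) -> u64 {
        // any locally unique id will do; there is only one stream here
        0
    }
}

// AsyncReadRecvStream::new(DuplexRecv(rx)) then implements RecvStream and can
// feed a StreamPair via StreamPair::new.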
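
Handler errors surface to the client as QUIC reset codes through the new HasErrorCode trait (see handle_write_result in src/provider.rs above). Custom handlers layered on StreamPair can participate the same way; the error type and code values below are hypothetical:

use iroh::endpoint::VarInt;
use iroh_blobs::provider::events::HasErrorCode;

#[derive(Debug)]
enum MyHandlerError {
    RateLimited,
    Internal,
}

impl HasErrorCode for MyHandlerError {
    fn code(&self) -> VarInt {
        match self {
            // hypothetical application-level error codes
            MyHandlerError::RateLimited => VarInt::from_u32(77),
            MyHandlerError::Internal => VarInt::from_u32(78),
        }
    }
}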
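
For reference, the LEB128 layout that read_varint_u64 decodes: 7 value bits per byte, MSB as continuation flag, least-significant group first. For example, 300 (0b1_0010_1100) encodes as [0xAC, 0x02]. A minimal standalone encoder for illustration; the crate itself writes varints via irpc's WriteVarintExt, as in SendStreamExt::write_length_prefixed above:

fn write_varint_u64(mut value: u64, out: &mut Vec<u8>) {
    loop {
        let mut byte = (value & 0x7F) as u8; // low 7 bits
        value >>= 7;
        if value != 0 {
            byte |= 0x80; // continuation bit: more bytes follow
        }
        out.push(byte);
        if value == 0 {
            break;
        }
    }
}

// write_varint_u64(300, &mut v) pushes [0xAC, 0x02]:
// 300 = 0b100101100 -> 0101100 with continuation = 0xAC, then 0000010 = 0x02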