diff --git a/Cargo.lock b/Cargo.lock index 1a4eeea..0d0c410 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -136,6 +136,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "async-compression" +version = "0.4.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a89bce6054c720275ac2432fbba080a66a2106a44a1b804553930ca6909f4e0" +dependencies = [ + "compression-codecs", + "compression-core", + "futures-core", + "pin-project-lite", + "tokio", +] + [[package]] name = "async-trait" version = "0.1.89" @@ -344,6 +357,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "65193589c6404eb80b450d618eaf9a2cafaaafd57ecce47370519ef674a7bd44" dependencies = [ "find-msvc-tools", + "jobserver", + "libc", "shlex", ] @@ -479,6 +494,29 @@ dependencies = [ "memchr", ] +[[package]] +name = "compression-codecs" +version = "0.4.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef8a506ec4b81c460798f572caead636d57d3d7e940f998160f52bd254bf2d23" +dependencies = [ + "compression-core", + "zstd", + "zstd-safe", +] + +[[package]] +name = "compression-core" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e47641d3deaf41fb1538ac1f54735925e275eaf3bf4d55c81b137fba797e5cbb" + +[[package]] +name = "concat_const" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60c92cd5ec953d0542f48d2a90a25aa2828ab1c03217c1ca077000f3af15997d" + [[package]] name = "console" version = "0.15.11" @@ -2081,6 +2119,16 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" +[[package]] +name = "jobserver" +version = "0.1.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" +dependencies = [ + "getrandom 0.3.3", + "libc", +] + [[package]] name = "js-sys" version = "0.3.79" @@ -2782,6 +2830,12 @@ dependencies = [ "spki", ] +[[package]] +name = "pkg-config" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" + [[package]] name = "pnet_base" version = "0.34.0" @@ -3532,7 +3586,9 @@ name = "sendme" version = "0.29.0" dependencies = [ "anyhow", + "async-compression", "clap", + "concat_const", "console", "crossterm", "data-encoding", @@ -5396,3 +5452,31 @@ dependencies = [ "quote", "syn 2.0.106", ] + +[[package]] +name = "zstd" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "7.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.16+zstd.1.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e19ebc2adc8f83e43039e79776e3fda8ca919132d68a1fed6a5faca2683748" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/Cargo.toml b/Cargo.toml index 696b9a8..9aa4f84 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,8 @@ walkdir = "2.4.0" data-encoding = "2.6.0" n0-future = "0.1.2" hex = "0.4.3" +async-compression = { version = "0.4.25", features = ["tokio", "zstd"], optional = true } +concat_const = { version = "0.2.0", optional = true} crossterm = { version = "0.29.0", features = [ "event-stream", "osc52", @@ -56,8 +58,9 @@ serde_json = "1.0.108" tempfile = "3.8.1" [features] +zstd = ["async-compression","concat_const"] clipboard = ["dep:crossterm", "dep:windows-sys", "dep:libc"] -default = ["clipboard"] +default = ["clipboard","zstd"] [profile.release] panic = "abort" diff --git a/src/main.rs b/src/main.rs index fdee664..57ba85c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,7 +2,7 @@ use std::{ collections::BTreeMap, - fmt::{Display, Formatter}, + fmt::{Debug, Display, Formatter}, net::{SocketAddrV4, SocketAddrV6}, path::{Component, Path, PathBuf}, str::FromStr, @@ -21,10 +21,18 @@ use futures_buffered::BufferedStreamExt; use indicatif::{ HumanBytes, HumanDuration, MultiProgress, ProgressBar, ProgressDrawTarget, ProgressStyle, }; + +#[cfg(feature = "zstd")] +use iroh::endpoint::ConnectOptions; + use iroh::{ discovery::{dns::DnsDiscovery, pkarr::PkarrPublisher}, Endpoint, NodeAddr, RelayMode, RelayUrl, SecretKey, }; + +#[cfg(feature = "zstd")] +use iroh::{endpoint::Connection, protocol::ProtocolHandler}; + use iroh_blobs::{ api::{ blobs::{ @@ -44,7 +52,23 @@ use iroh_blobs::{ ticket::BlobTicket, BlobFormat, BlobsProtocol, Hash, }; + +#[cfg(feature = "zstd")] +use iroh_blobs::{ + api::remote::GetStreamPair, + get::StreamPair, + provider::{ + events::{ClientConnected, HasErrorCode}, + handle_stream, + }, + util::{RecvStream, SendStream}, +}; + use n0_future::{task::AbortOnDropHandle, FuturesUnordered, StreamExt}; + +#[cfg(feature = "zstd")] +use n0_future::io; + use rand::Rng; use serde::{Deserialize, Serialize}; use tokio::{select, sync::mpsc}; @@ -175,7 +199,7 @@ impl Display for RelayModeOption { match self { Self::Disabled => f.write_str("disabled"), Self::Default => f.write_str("default"), - Self::Custom(url) => url.fmt(f), + Self::Custom(url) => std::fmt::Display::fmt(&url, f), } } } @@ -221,6 +245,16 @@ pub struct SendArgs { #[cfg(feature = "clipboard")] #[clap(short = 'c', long)] pub clipboard: bool, + + /// Use zstd to compress outgoing and decompress incoming data + #[cfg(feature = "zstd")] + #[clap(short = 'z', long)] + pub zstd: bool, + + /// Compression level for zstd + #[cfg(feature = "zstd")] + #[clap(short = 'q', long, default_value_t = 3, requires("zstd"))] + pub compression_quality: u8, } #[derive(Parser, Debug)] @@ -302,6 +336,160 @@ fn validate_path_component(component: &str) -> anyhow::Result<()> { Ok(()) } +#[cfg(feature = "zstd")] +trait Compression: Clone + Send + Sync + Debug + 'static { + const ALPN: &'static [u8]; + fn recv_stream( + &self, + stream: iroh::endpoint::RecvStream, + ) -> impl iroh_blobs::util::RecvStream + Sync + 'static; + fn send_stream( + &self, + stream: iroh::endpoint::SendStream, + compression_level: u8, + ) -> impl iroh_blobs::util::SendStream + Sync + 'static; +} + +#[cfg(feature = "zstd")] +mod zstd { + use std::io; + + use async_compression::tokio::{bufread::ZstdDecoder, write::ZstdEncoder}; + use async_compression::Level; + use iroh::endpoint::VarInt; + use iroh_blobs::util::{ + AsyncReadRecvStream, AsyncReadRecvStreamExtra, AsyncWriteSendStream, + AsyncWriteSendStreamExtra, + }; + use tokio::io::{AsyncRead, AsyncWrite, BufReader}; + + struct SendStream(ZstdEncoder); + + impl SendStream { + pub fn new( + inner: iroh::endpoint::SendStream, + compression_level: u8, + ) -> AsyncWriteSendStream { + let c_level = compression_level.clamp(1, 22); + AsyncWriteSendStream::new(Self(ZstdEncoder::with_quality( + inner, + Level::Precise(c_level as _), + ))) + } + } + + impl AsyncWriteSendStreamExtra for SendStream { + fn inner(&mut self) -> &mut (impl AsyncWrite + Unpin + Send) { + &mut self.0 + } + + fn reset(&mut self, code: VarInt) -> io::Result<()> { + Ok(self.0.get_mut().reset(code)?) + } + + async fn stopped(&mut self) -> io::Result> { + Ok(self.0.get_mut().stopped().await?) + } + + fn id(&self) -> u64 { + self.0.get_ref().id().index() + } + } + + struct RecvStream(ZstdDecoder>); + + impl RecvStream { + pub fn new(inner: iroh::endpoint::RecvStream) -> AsyncReadRecvStream { + AsyncReadRecvStream::new(Self(ZstdDecoder::new(BufReader::new(inner)))) + } + } + + impl AsyncReadRecvStreamExtra for RecvStream { + fn inner(&mut self) -> &mut (impl AsyncRead + Unpin + Send) { + &mut self.0 + } + + fn stop(&mut self, code: VarInt) -> io::Result<()> { + Ok(self.0.get_mut().get_mut().stop(code)?) + } + + fn id(&self) -> u64 { + self.0.get_ref().get_ref().id().index() + } + } + + #[derive(Debug, Clone)] + pub struct Compression; + + impl super::Compression for Compression { + const ALPN: &[u8] = concat_const::concat_bytes!(b"zstd/", iroh_blobs::ALPN); + fn recv_stream( + &self, + stream: iroh::endpoint::RecvStream, + ) -> impl iroh_blobs::util::RecvStream + Sync + 'static { + RecvStream::new(stream) + } + fn send_stream( + &self, + stream: iroh::endpoint::SendStream, + compression_level: u8, + ) -> impl iroh_blobs::util::SendStream + Sync + 'static { + SendStream::new(stream, compression_level) + } + } +} + +#[cfg(feature = "zstd")] +#[derive(Debug, Clone)] +struct CompressedBlobsProtocol { + store: Store, + events: EventSender, + compression: C, + compression_level: u8, +} + +#[cfg(feature = "zstd")] +impl CompressedBlobsProtocol { + fn new(store: &Store, events: EventSender, compression: C, compression_level: u8) -> Self { + Self { + store: store.clone(), + events, + compression, + compression_level, + } + } +} + +#[cfg(feature = "zstd")] +impl ProtocolHandler for CompressedBlobsProtocol { + async fn accept( + &self, + connection: iroh::endpoint::Connection, + ) -> std::result::Result<(), iroh::protocol::AcceptError> { + let connection_id = connection.stable_id() as u64; + if let Err(cause) = self + .events + .client_connected(|| ClientConnected { + connection_id, + node_id: connection.remote_node_id().ok(), + }) + .await + { + connection.close(cause.code(), cause.reason()); + // debug!("closing connection: {cause}"); + return Ok(()); + } + while let Ok((send, recv)) = connection.accept_bi().await { + let send = self.compression.send_stream(send, self.compression_level); + let recv = self.compression.recv_stream(recv); + let store = self.store.clone(); + let pair = provider::StreamPair::new(connection_id, recv, send, self.events.clone()); + tokio::spawn(handle_stream(pair, store)); + } + Ok(()) + } +} + /// This function converts an already canonicalized path to a string. /// /// If `must_be_relative` is true, the function will fail if any component of the path is @@ -634,9 +822,26 @@ async fn send(args: SendArgs) -> anyhow::Result<()> { eprintln!("using secret key {secret_key}"); } // create a magicsocket endpoint + let alpn: Vec = { + #[cfg(feature = "zstd")] + { + if args.zstd { + zstd::Compression::ALPN.to_vec() + } else { + iroh_blobs::protocol::ALPN.to_vec() + } + } + + #[cfg(not(feature = "zstd"))] + { + // When the feature isn't enabled, we just fall back + iroh_blobs::protocol::ALPN.to_vec() + } + }; + let relay_mode: RelayMode = args.common.relay.into(); let mut builder = Endpoint::builder() - .alpns(vec![iroh_blobs::protocol::ALPN.to_vec()]) + .alpns(vec![alpn]) .secret_key(secret_key) .relay_mode(relay_mode.clone()); if args.ticket_type == AddrInfoOptions::Id { @@ -689,24 +894,52 @@ async fn send(args: SendArgs) -> anyhow::Result<()> { }; mp.set_draw_target(draw_target); let store = FsStore::load(&blobs_data_dir2).await?; - let blobs = BlobsProtocol::new( - &store, - Some(EventSender::new( - progress_tx, - EventMask { - connected: ConnectMode::Notify, - get: provider::events::RequestMode::NotifyLog, - ..EventMask::DEFAULT - }, - )), + + let event_sender = EventSender::new( + progress_tx, + EventMask { + connected: ConnectMode::Notify, + get: provider::events::RequestMode::NotifyLog, + ..EventMask::DEFAULT + }, ); - let import_result = import(path2, blobs.store(), &mut mp).await?; - let dt = t0.elapsed(); + let (router, import_result, dt) = { + #[cfg(feature = "zstd")] + { + if args.zstd { + let compression = zstd::Compression; + let blobs = CompressedBlobsProtocol::new( + &store, + event_sender, + compression, + args.compression_quality, + ); + let import_result = import(path2, &blobs.store, &mut mp).await?; + let router = iroh::protocol::Router::builder(endpoint) + .accept(zstd::Compression::ALPN, blobs.clone()) + .spawn(); + (router, import_result, t0.elapsed()) + } else { + let blobs = BlobsProtocol::new(&store, Some(event_sender)); + let import_result = import(path2, blobs.store(), &mut mp).await?; + let router = iroh::protocol::Router::builder(endpoint) + .accept(iroh_blobs::ALPN, blobs.clone()) + .spawn(); + (router, import_result, t0.elapsed()) + } + } - let router = iroh::protocol::Router::builder(endpoint) - .accept(iroh_blobs::ALPN, blobs.clone()) - .spawn(); + #[cfg(not(feature = "zstd"))] + { + let blobs = BlobsProtocol::new(&store, Some(event_sender)); + let import_result = import(path2, blobs.store(), &mut mp).await?; + let router = iroh::protocol::Router::builder(endpoint) + .accept(iroh_blobs::ALPN, blobs.clone()) + .spawn(); + (router, import_result, t0.elapsed()) + } + }; // wait for the endpoint to figure out its address before making a ticket let ep = router.endpoint(); @@ -986,6 +1219,22 @@ fn show_get_error(e: GetError) -> GetError { e } +#[cfg(feature = "zstd")] +struct ZstdConn { + connection: Connection, +} + +#[cfg(feature = "zstd")] +impl GetStreamPair for ZstdConn { + async fn open_stream_pair(self) -> io::Result> { + let connection_id = self.connection.stable_id() as u64; + let (send, recv) = self.connection.open_bi().await?; + let send = zstd::Compression.send_stream(send, 3); + let recv = zstd::Compression.recv_stream(recv); + Ok(StreamPair::new(connection_id, recv, send)) + } +} + async fn receive(args: ReceiveArgs) -> anyhow::Result<()> { let ticket = args.ticket; let addr = ticket.node_addr().clone(); @@ -1026,34 +1275,81 @@ async fn receive(args: ReceiveArgs) -> anyhow::Result<()> { let (stats, total_files, payload_size) = if !local.is_complete() { trace!("{} not complete", hash_and_format.hash); let cp = mp.add(make_connect_progress()); - let connection = endpoint.connect(addr, iroh_blobs::protocol::ALPN).await?; - cp.finish_and_clear(); - let sp = mp.add(make_get_sizes_progress()); - let (_hash_seq, sizes) = - get_hash_seq_and_sizes(&connection, &hash_and_format.hash, 1024 * 1024 * 32, None) - .await - .map_err(show_get_error)?; - sp.finish_and_clear(); - let total_size = sizes.iter().copied().sum::(); - let payload_size = sizes.iter().skip(2).copied().sum::(); - let total_files = (sizes.len().saturating_sub(1)) as u64; - eprintln!( - "getting collection {} {} files, {}", - print_hash(&ticket.hash(), args.common.format), - total_files, - HumanBytes(payload_size) - ); - // print the details of the collection only in verbose mode - if args.common.verbose > 0 { + + #[cfg(feature = "zstd")] + let options = + ConnectOptions::new().with_additional_alpns(vec![zstd::Compression::ALPN.to_vec()]); + + #[cfg(not(feature = "zstd"))] + let options = Default::default(); + + let mut connecting = endpoint + .connect_with_opts(addr, iroh_blobs::ALPN, options) + .await?; + + let using_zstd: bool = match connecting.alpn().await { + #[cfg(feature = "zstd")] + Ok(alpn_vec) => alpn_vec == zstd::Compression::ALPN.to_vec(), + #[cfg(not(feature = "zstd"))] + Ok(_) => false, + + Err(e) => { + anyhow::bail!( + "This build of sendme does not support receiving with compression: {}", + e + ); + } + }; + + let connection = connecting.await?; + let (total_size, payload_size, total_files) = if !using_zstd { + cp.finish_and_clear(); + let sp = mp.add(make_get_sizes_progress()); + let (_hash_seq, sizes) = get_hash_seq_and_sizes( + &connection, + &hash_and_format.hash, + 1024 * 1024 * 32, + None, + ) + .await + .map_err(show_get_error)?; + sp.finish_and_clear(); + let total_size = sizes.iter().copied().sum::(); + let payload_size = sizes.iter().skip(2).copied().sum::(); + let total_files = (sizes.len().saturating_sub(1)) as u64; eprintln!( - "getting {} blobs in total, {}", - total_files + 1, - HumanBytes(total_size) + "getting collection {} {} files, {}", + print_hash(&ticket.hash(), args.common.format), + total_files, + HumanBytes(payload_size) ); - } + // print the details of the collection only in verbose mode + if args.common.verbose > 0 { + eprintln!( + "getting {} blobs in total, {}", + total_files + 1, + HumanBytes(total_size) + ); + } + (total_size, payload_size, total_files) + } else { + (0, 0, 0) + }; + let (tx, rx) = mpsc::channel(32); let local_size = local.local_bytes(); + + #[cfg(feature = "zstd")] + let get = if using_zstd { + let zstd_conn = ZstdConn { connection }; + db.remote().execute_get(zstd_conn, local.missing()) + } else { + db.remote().execute_get(connection, local.missing()) + }; + + #[cfg(not(feature = "zstd"))] let get = db.remote().execute_get(connection, local.missing()); + let task = tokio::spawn(show_download_progress( mp.clone(), rx,