From 58f5c77728d7e5cebbcd2d14d2b69e7c276533aa Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Thu, 21 Aug 2025 11:57:35 +0200 Subject: [PATCH 01/16] Replace hacky connection pool with the one from https://github.com/n0-computer/iroh-experiments/pull/36 --- Cargo.lock | 33 ++- Cargo.toml | 2 +- src/api/downloader.rs | 168 +------------ src/api/remote.rs | 8 +- src/util.rs | 2 + src/util/connection_pool.rs | 465 ++++++++++++++++++++++++++++++++++++ 6 files changed, 507 insertions(+), 171 deletions(-) create mode 100644 src/util/connection_pool.rs diff --git a/Cargo.lock b/Cargo.lock index b966614a4..19c59bac3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1683,7 +1683,7 @@ dependencies = [ "iroh-quinn-proto", "iroh-quinn-udp", "iroh-relay", - "n0-future", + "n0-future 0.1.3", "n0-snafu", "n0-watcher", "nested_enum_utils", @@ -1758,7 +1758,7 @@ dependencies = [ "iroh-quinn", "iroh-test", "irpc", - "n0-future", + "n0-future 0.2.0", "n0-snafu", "nested_enum_utils", "postcard", @@ -1900,7 +1900,7 @@ dependencies = [ "iroh-quinn", "iroh-quinn-proto", "lru", - "n0-future", + "n0-future 0.1.3", "n0-snafu", "nested_enum_utils", "num_enum", @@ -1951,7 +1951,7 @@ dependencies = [ "futures-util", "iroh-quinn", "irpc-derive", - "n0-future", + "n0-future 0.1.3", "postcard", "rcgen", "rustls", @@ -2173,6 +2173,27 @@ dependencies = [ "web-time", ] +[[package]] +name = "n0-future" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89d7dd42bd0114c9daa9c4f2255d692a73bba45767ec32cf62892af6fe5d31f6" +dependencies = [ + "cfg_aliases", + "derive_more 1.0.0", + "futures-buffered", + "futures-lite", + "futures-util", + "js-sys", + "pin-project", + "send_wrapper", + "tokio", + "tokio-util", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-time", +] + [[package]] name = "n0-snafu" version = "0.2.1" @@ -2193,7 +2214,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c31462392a10d5ada4b945e840cbec2d5f3fee752b96c4b33eb41414d8f45c2a" dependencies = [ "derive_more 1.0.0", - "n0-future", + "n0-future 0.1.3", "snafu", ] @@ -2319,7 +2340,7 @@ dependencies = [ "iroh-quinn-udp", "js-sys", "libc", - "n0-future", + "n0-future 0.1.3", "n0-watcher", "nested_enum_utils", "netdev", diff --git a/Cargo.toml b/Cargo.toml index 3f9f47a9d..bcd5f42d0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,7 +18,7 @@ bytes = { version = "1", features = ["serde"] } derive_more = { version = "2.0.1", features = ["from", "try_from", "into", "debug", "display", "deref", "deref_mut"] } futures-lite = "2.6.0" quinn = { package = "iroh-quinn", version = "0.14.0" } -n0-future = "0.1.2" +n0-future = "0.2.0" n0-snafu = "0.2.0" range-collections = { version = "0.4.6", features = ["serde"] } redb = { version = "=2.4" } diff --git a/src/api/downloader.rs b/src/api/downloader.rs index 678a8c6ad..52192deb3 100644 --- a/src/api/downloader.rs +++ b/src/api/downloader.rs @@ -23,7 +23,10 @@ use tracing::{info, instrument::Instrument, warn}; use super::{remote::GetConnection, Store}; use crate::{ protocol::{GetManyRequest, GetRequest}, - util::sink::{Drain, IrpcSenderRefSink, Sink, TokioMpscSenderSink}, + util::{ + connection_pool::ConnectionPool, + sink::{Drain, IrpcSenderRefSink, Sink, TokioMpscSenderSink}, + }, BlobFormat, Hash, HashAndFormat, }; @@ -69,7 +72,7 @@ impl DownloaderActor { fn new(store: Store, endpoint: Endpoint) -> Self { Self { store, - pool: ConnectionPool::new(endpoint, crate::ALPN.to_vec()), + pool: ConnectionPool::new(endpoint, crate::ALPN, Default::default()), 
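+            // Default::default() here is connection_pool::Options: 5 s idle
+            // timeout, 1 s connect timeout, at most 1024 pooled connections.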
tasks: JoinSet::new(), running: HashSet::new(), } @@ -414,90 +417,6 @@ async fn split_request<'a>( }) } -#[derive(Debug)] -struct ConnectionPoolInner { - alpn: Vec, - endpoint: Endpoint, - connections: Mutex>>>, - retry_delay: Duration, - connect_timeout: Duration, -} - -#[derive(Debug, Clone)] -struct ConnectionPool(Arc); - -#[derive(Debug, Default)] -enum SlotState { - #[default] - Initial, - Connected(Connection), - AttemptFailed(SystemTime), - #[allow(dead_code)] - Evil(String), -} - -impl ConnectionPool { - fn new(endpoint: Endpoint, alpn: Vec) -> Self { - Self( - ConnectionPoolInner { - endpoint, - alpn, - connections: Default::default(), - retry_delay: Duration::from_secs(5), - connect_timeout: Duration::from_secs(2), - } - .into(), - ) - } - - pub fn alpn(&self) -> &[u8] { - &self.0.alpn - } - - pub fn endpoint(&self) -> &Endpoint { - &self.0.endpoint - } - - pub fn retry_delay(&self) -> Duration { - self.0.retry_delay - } - - fn dial(&self, id: NodeId) -> DialNode { - DialNode { - pool: self.clone(), - id, - } - } - - #[allow(dead_code)] - async fn mark_evil(&self, id: NodeId, reason: String) { - let slot = self - .0 - .connections - .lock() - .await - .entry(id) - .or_default() - .clone(); - let mut t = slot.lock().await; - *t = SlotState::Evil(reason) - } - - #[allow(dead_code)] - async fn mark_closed(&self, id: NodeId) { - let slot = self - .0 - .connections - .lock() - .await - .entry(id) - .or_default() - .clone(); - let mut t = slot.lock().await; - *t = SlotState::Initial - } -} - /// Execute a get request sequentially for multiple providers. /// /// It will try each provider in order @@ -526,13 +445,13 @@ async fn execute_get( request: request.clone(), }) .await?; - let mut conn = pool.dial(provider); + let mut conn = pool.connect(provider); let local = remote.local_for_request(request.clone()).await?; if local.is_complete() { return Ok(()); } let local_bytes = local.local_bytes(); - let Ok(conn) = conn.connection().await else { + let Ok(conn) = conn.await else { progress .send(DownloadProgessItem::ProviderFailed { id: provider, @@ -543,7 +462,7 @@ async fn execute_get( }; match remote .execute_get_sink( - conn, + &conn, local.missing(), (&mut progress).with_map(move |x| DownloadProgessItem::Progress(x + local_bytes)), ) @@ -571,77 +490,6 @@ async fn execute_get( bail!("Unable to download {}", request.hash); } -#[derive(Debug, Clone)] -struct DialNode { - pool: ConnectionPool, - id: NodeId, -} - -impl DialNode { - async fn connection_impl(&self) -> anyhow::Result { - info!("Getting connection for node {}", self.id); - let slot = self - .pool - .0 - .connections - .lock() - .await - .entry(self.id) - .or_default() - .clone(); - info!("Dialing node {}", self.id); - let mut guard = slot.lock().await; - match guard.deref() { - SlotState::Connected(conn) => { - return Ok(conn.clone()); - } - SlotState::AttemptFailed(time) => { - let elapsed = time.elapsed().unwrap_or_default(); - if elapsed <= self.pool.retry_delay() { - bail!( - "Connection attempt failed {} seconds ago", - elapsed.as_secs_f64() - ); - } - } - SlotState::Evil(reason) => { - bail!("Node is banned due to evil behavior: {reason}"); - } - SlotState::Initial => {} - } - let res = self - .pool - .endpoint() - .connect(self.id, self.pool.alpn()) - .timeout(self.pool.0.connect_timeout) - .await; - match res { - Ok(Ok(conn)) => { - info!("Connected to node {}", self.id); - *guard = SlotState::Connected(conn.clone()); - Ok(conn) - } - Ok(Err(e)) => { - warn!("Failed to connect to node {}: {}", self.id, e); - *guard = 
SlotState::AttemptFailed(SystemTime::now()); - Err(e.into()) - } - Err(e) => { - warn!("Failed to connect to node {}: {}", self.id, e); - *guard = SlotState::AttemptFailed(SystemTime::now()); - bail!("Failed to connect to node: {}", e); - } - } - } -} - -impl GetConnection for DialNode { - fn connection(&mut self) -> impl Future> + '_ { - let this = self.clone(); - async move { this.connection_impl().await } - } -} - /// Trait for pluggable content discovery strategies. pub trait ContentDiscovery: Debug + Send + Sync + 'static { fn find_providers(&self, hash: HashAndFormat) -> n0_future::stream::Boxed; diff --git a/src/api/remote.rs b/src/api/remote.rs index 47c3eea27..945bb7332 100644 --- a/src/api/remote.rs +++ b/src/api/remote.rs @@ -518,7 +518,7 @@ impl Remote { .connection() .await .map_err(|e| LocalFailureSnafu.into_error(e.into()))?; - let stats = self.execute_get_sink(conn, request, progress).await?; + let stats = self.execute_get_sink(&conn, request, progress).await?; Ok(stats) } @@ -637,7 +637,7 @@ impl Remote { .with_map_err(io::Error::other); let this = self.clone(); let fut = async move { - let res = this.execute_get_sink(conn, request, sink).await.into(); + let res = this.execute_get_sink(&conn, request, sink).await.into(); tx2.send(res).await.ok(); }; GetProgress { @@ -656,13 +656,13 @@ impl Remote { /// This will return the stats of the download. pub(crate) async fn execute_get_sink( &self, - conn: Connection, + conn: &Connection, request: GetRequest, mut progress: impl Sink, ) -> GetResult { let store = self.store(); let root = request.hash; - let start = crate::get::fsm::start(conn, request, Default::default()); + let start = crate::get::fsm::start(conn.clone(), request, Default::default()); let connected = start.next().await?; trace!("Getting header"); // read the header diff --git a/src/util.rs b/src/util.rs index 7b9ad4e6e..f3e493934 100644 --- a/src/util.rs +++ b/src/util.rs @@ -4,7 +4,9 @@ use bao_tree::{io::round_up_to_chunks, ChunkNum, ChunkRanges}; use range_collections::{range_set::RangeSetEntry, RangeSet2}; pub mod channel; +pub(crate) mod connection_pool; pub(crate) mod temp_tag; + pub mod serde { // Module that handles io::Error serialization/deserialization pub mod io_error_serde { diff --git a/src/util/connection_pool.rs b/src/util/connection_pool.rs new file mode 100644 index 000000000..1ae0a918c --- /dev/null +++ b/src/util/connection_pool.rs @@ -0,0 +1,465 @@ +//! A simple iroh connection pool +//! +//! Entry point is [`ConnectionPool`]. You create a connection pool for a specific +//! ALPN and [`Options`]. Then the pool will manage connections for you. +//! +//! Access to connections is via the [`ConnectionPool::connect`] method, which +//! gives you access to a connection if possible. +//! +//! It is important that you use the connection only in the future passed to +//! connect, and don't clone it out of the future. 
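+//!
+//! A minimal usage sketch (hypothetical `MY_ALPN` constant and an already
+//! bound endpoint; error handling elided):
+//!
+//! ```ignore
+//! use iroh_blobs::util::connection_pool::{ConnectionPool, Options};
+//!
+//! const MY_ALPN: &[u8] = b"my-proto/0";
+//!
+//! async fn ping(endpoint: iroh::Endpoint, node: iroh::NodeId) -> anyhow::Result<()> {
+//!     let pool = ConnectionPool::new(endpoint, MY_ALPN, Options::default());
+//!     // returns within roughly Options::connect_timeout, reusing an
+//!     // existing connection to `node` if there is one
+//!     let conn = pool.connect(node).await?;
+//!     // the returned ConnectionRef derefs to iroh::endpoint::Connection
+//!     let (mut send, _recv) = conn.open_bi().await?;
+//!     send.write_all(b"ping").await?;
+//!     send.finish()?;
+//!     Ok(())
+//! }
+//! ```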
+use std::{ + collections::{HashMap, VecDeque}, + ops::Deref, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + time::Duration, +}; + +use iroh::{ + endpoint::{ConnectError, Connection}, + Endpoint, NodeId, +}; +use n0_future::{ + future::{self}, + FuturesUnordered, MaybeFuture, Stream, StreamExt, +}; +use snafu::Snafu; +use tokio::sync::{ + mpsc::{self, error::SendError as TokioSendError}, + oneshot, Notify, +}; +use tokio_util::time::FutureExt as TimeFutureExt; +use tracing::{debug, error, trace}; + +/// Configuration options for the connection pool +#[derive(Debug, Clone, Copy)] +pub struct Options { + pub idle_timeout: Duration, + pub connect_timeout: Duration, + pub max_connections: usize, +} + +impl Default for Options { + fn default() -> Self { + Self { + idle_timeout: Duration::from_secs(5), + connect_timeout: Duration::from_secs(1), + max_connections: 1024, + } + } +} + +/// A reference to a connection that is owned by a connection pool. +#[derive(Debug)] +pub struct ConnectionRef { + connection: iroh::endpoint::Connection, + _permit: OneConnection, +} + +impl Deref for ConnectionRef { + type Target = iroh::endpoint::Connection; + + fn deref(&self) -> &Self::Target { + &self.connection + } +} + +impl ConnectionRef { + fn new(connection: iroh::endpoint::Connection, counter: OneConnection) -> Self { + Self { + connection, + _permit: counter, + } + } +} + +/// Error when a connection can not be acquired +/// +/// This includes the normal iroh connection errors as well as pool specific +/// errors such as timeouts and connection limits. +#[derive(Debug, Clone, Snafu)] +#[snafu(module)] +pub enum PoolConnectError { + /// Connection pool is shut down + Shutdown, + /// Timeout during connect + Timeout, + /// Too many connections + TooManyConnections, + /// Error during connect + ConnectError { source: Arc }, +} + +impl From for PoolConnectError { + fn from(e: ConnectError) -> Self { + PoolConnectError::ConnectError { + source: Arc::new(e), + } + } +} + +/// Error when calling a fn on the [`ConnectionPool`]. +/// +/// The only thing that can go wrong is that the connection pool is shut down. +#[derive(Debug, Snafu)] +#[snafu(module)] +pub enum ConnectionPoolError { + /// The connection pool has been shut down + Shutdown, +} + +pub type PoolConnectResult = std::result::Result; + +enum ActorMessage { + RequestRef(RequestRef), + ConnectionIdle { id: NodeId }, + ConnectionShutdown { id: NodeId }, +} + +struct RequestRef { + id: NodeId, + tx: oneshot::Sender>, +} + +struct Context { + options: Options, + endpoint: Endpoint, + owner: ConnectionPool, + alpn: Vec, +} + +impl Context { + async fn run_connection_actor( + self: Arc, + node_id: NodeId, + mut rx: mpsc::Receiver, + ) { + let context = self; + + // Connect to the node + let state = context + .endpoint + .connect(node_id, &context.alpn) + .timeout(context.options.connect_timeout) + .await + .map_err(|_| PoolConnectError::Timeout) + .and_then(|r| r.map_err(PoolConnectError::from)); + if let Err(e) = &state { + debug!(%node_id, "Failed to connect {e:?}, requesting shutdown"); + if context.owner.close(node_id).await.is_err() { + return; + } + } + let counter = ConnectionCounter::new(); + let idle_timer = MaybeFuture::default(); + let idle_stream = counter.clone().idle_stream(); + + tokio::pin!(idle_timer, idle_stream); + + loop { + tokio::select! 
{ + biased; + + // Handle new work + handler = rx.recv() => { + match handler { + Some(RequestRef { id, tx }) => { + assert!(id == node_id, "Not for me!"); + match &state { + Ok(state) => { + let res = ConnectionRef::new(state.clone(), counter.get_one()); + + // clear the idle timer + idle_timer.as_mut().set_none(); + tx.send(Ok(res)).ok(); + } + Err(cause) => { + tx.send(Err(cause.clone())).ok(); + } + } + } + None => { + // Channel closed - finish remaining tasks and exit + break; + } + } + } + + _ = idle_stream.next() => { + if !counter.is_idle() { + continue; + }; + // notify the pool that we are idle. + trace!(%node_id, "Idle"); + if context.owner.idle(node_id).await.is_err() { + // If we can't notify the pool, we are shutting down + break; + } + // set the idle timer + idle_timer.as_mut().set_future(tokio::time::sleep(context.options.idle_timeout)); + } + + // Idle timeout - request shutdown + _ = &mut idle_timer => { + trace!(%node_id, "Idle timer expired, requesting shutdown"); + context.owner.close(node_id).await.ok(); + // Don't break here - wait for main actor to close our channel + } + } + } + + if let Ok(connection) = state { + let reason = if counter.is_idle() { b"idle" } else { b"drop" }; + connection.close(0u32.into(), reason); + } + + trace!(%node_id, "Connection actor shutting down"); + } +} + +struct Actor { + rx: mpsc::Receiver, + connections: HashMap>, + context: Arc, + // idle set (most recent last) + // todo: use a better data structure if this becomes a performance issue + idle: VecDeque, + // per connection tasks + tasks: FuturesUnordered>, +} + +impl Actor { + pub fn new( + endpoint: Endpoint, + alpn: &[u8], + options: Options, + ) -> (Self, mpsc::Sender) { + let (tx, rx) = mpsc::channel(100); + ( + Self { + rx, + connections: HashMap::new(), + idle: VecDeque::new(), + context: Arc::new(Context { + options, + alpn: alpn.to_vec(), + endpoint, + owner: ConnectionPool { tx: tx.clone() }, + }), + tasks: FuturesUnordered::new(), + }, + tx, + ) + } + + fn add_idle(&mut self, id: NodeId) { + self.remove_idle(id); + self.idle.push_back(id); + } + + fn remove_idle(&mut self, id: NodeId) { + self.idle.retain(|&x| x != id); + } + + fn pop_oldest_idle(&mut self) -> Option { + self.idle.pop_front() + } + + fn remove_connection(&mut self, id: NodeId) { + self.connections.remove(&id); + self.remove_idle(id); + } + + async fn handle_msg(&mut self, msg: ActorMessage) { + match msg { + ActorMessage::RequestRef(mut msg) => { + let id = msg.id; + self.remove_idle(id); + // Try to send to existing connection actor + if let Some(conn_tx) = self.connections.get(&id) { + if let Err(TokioSendError(e)) = conn_tx.send(msg).await { + msg = e; + } else { + return; + } + // Connection actor died, remove it + self.remove_connection(id); + } + + // No connection actor or it died - check limits + if self.connections.len() >= self.context.options.max_connections { + if let Some(idle) = self.pop_oldest_idle() { + // remove the oldest idle connection to make room for one more + trace!("removing oldest idle connection {}", idle); + self.connections.remove(&idle); + } else { + msg.tx.send(Err(PoolConnectError::TooManyConnections)).ok(); + return; + } + } + let (conn_tx, conn_rx) = mpsc::channel(100); + self.connections.insert(id, conn_tx.clone()); + + let context = self.context.clone(); + + self.tasks + .push(Box::pin(context.run_connection_actor(id, conn_rx))); + + // Send the handler to the new actor + if conn_tx.send(msg).await.is_err() { + error!(%id, "Failed to send handler to new connection 
actor"); + self.connections.remove(&id); + } + } + ActorMessage::ConnectionIdle { id } => { + self.add_idle(id); + trace!(%id, "connection idle"); + } + ActorMessage::ConnectionShutdown { id } => { + // Remove the connection from our map - this closes the channel + self.remove_connection(id); + trace!(%id, "removed connection"); + } + } + } + + pub async fn run(mut self) { + loop { + tokio::select! { + biased; + + msg = self.rx.recv() => { + if let Some(msg) = msg { + self.handle_msg(msg).await; + } else { + break; + } + } + + _ = self.tasks.next(), if !self.tasks.is_empty() => {} + } + } + } +} + +/// A connection pool +#[derive(Debug, Clone)] +pub struct ConnectionPool { + tx: mpsc::Sender, +} + +impl ConnectionPool { + pub fn new(endpoint: Endpoint, alpn: &[u8], options: Options) -> Self { + let (actor, tx) = Actor::new(endpoint, alpn, options); + + // Spawn the main actor + tokio::spawn(actor.run()); + + Self { tx } + } + + /// Returns either a fresh connection or a reference to an existing one. + /// + /// This is guaranteed to return after approximately [Options::connect_timeout] + /// with either an error or a connection. + pub async fn connect( + &self, + id: NodeId, + ) -> std::result::Result { + let (tx, rx) = oneshot::channel(); + self.tx + .send(ActorMessage::RequestRef(RequestRef { id, tx })) + .await + .map_err(|_| PoolConnectError::Shutdown)?; + rx.await.map_err(|_| PoolConnectError::Shutdown)? + } + + /// Close an existing connection, if it exists + /// + /// This will finish pending tasks and close the connection. New tasks will + /// get a new connection if they are submitted after this call + pub async fn close(&self, id: NodeId) -> std::result::Result<(), ConnectionPoolError> { + self.tx + .send(ActorMessage::ConnectionShutdown { id }) + .await + .map_err(|_| ConnectionPoolError::Shutdown)?; + Ok(()) + } + + /// Notify the connection pool that a connection is idle. + /// + /// Should only be called from connection handlers. + pub(crate) async fn idle(&self, id: NodeId) -> std::result::Result<(), ConnectionPoolError> { + self.tx + .send(ActorMessage::ConnectionIdle { id }) + .await + .map_err(|_| ConnectionPoolError::Shutdown)?; + Ok(()) + } +} + +#[derive(Debug)] +struct ConnectionCounterInner { + count: AtomicUsize, + notify: Notify, +} + +#[derive(Debug, Clone)] +struct ConnectionCounter { + inner: Arc, +} + +impl ConnectionCounter { + fn new() -> Self { + Self { + inner: Arc::new(ConnectionCounterInner { + count: Default::default(), + notify: Notify::new(), + }), + } + } + + /// Increase the connection count and return a guard for the new connection + fn get_one(&self) -> OneConnection { + self.inner.count.fetch_add(1, Ordering::SeqCst); + OneConnection { + inner: self.inner.clone(), + } + } + + fn is_idle(&self) -> bool { + self.inner.count.load(Ordering::SeqCst) == 0 + } + + /// Infinite stream that yields when the connection is briefly idle. + /// + /// Note that you still have to check if the connection is still idle when + /// you get the notification. + /// + /// Also note that this stream is triggered on [OneConnection::drop], so it + /// won't trigger initially even though a [ConnectionCounter] starts up as + /// idle. 
+ fn idle_stream(self) -> impl Stream { + n0_future::stream::unfold(self, |c| async move { + c.inner.notify.notified().await; + Some(((), c)) + }) + } +} + +/// Guard for one connection +#[derive(Debug)] +struct OneConnection { + inner: Arc, +} + +impl Drop for OneConnection { + fn drop(&mut self) { + if self.inner.count.fetch_sub(1, Ordering::SeqCst) == 1 { + self.inner.notify.notify_waiters(); + } + } +} From a33d96671c313208ebca288dd51dc5ceb71e2fad Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Thu, 21 Aug 2025 12:02:10 +0200 Subject: [PATCH 02/16] clippy --- src/api/downloader.rs | 13 +++++-------- src/api/remote.rs | 2 ++ src/util/connection_pool.rs | 4 +--- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/api/downloader.rs b/src/api/downloader.rs index 52192deb3..0d5375978 100644 --- a/src/api/downloader.rs +++ b/src/api/downloader.rs @@ -4,23 +4,20 @@ use std::{ fmt::Debug, future::{Future, IntoFuture}, io, - ops::Deref, sync::Arc, - time::{Duration, SystemTime}, }; use anyhow::bail; use genawaiter::sync::Gen; -use iroh::{endpoint::Connection, Endpoint, NodeId}; +use iroh::{Endpoint, NodeId}; use irpc::{channel::mpsc, rpc_requests}; use n0_future::{future, stream, BufferedStreamExt, Stream, StreamExt}; use rand::seq::SliceRandom; use serde::{de::Error, Deserialize, Serialize}; -use tokio::{sync::Mutex, task::JoinSet}; -use tokio_util::time::FutureExt; -use tracing::{info, instrument::Instrument, warn}; +use tokio::task::JoinSet; +use tracing::instrument::Instrument; -use super::{remote::GetConnection, Store}; +use super::Store; use crate::{ protocol::{GetManyRequest, GetRequest}, util::{ @@ -445,7 +442,7 @@ async fn execute_get( request: request.clone(), }) .await?; - let mut conn = pool.connect(provider); + let conn = pool.connect(provider); let local = remote.local_for_request(request.clone()).await?; if local.is_complete() { return Ok(()); diff --git a/src/api/remote.rs b/src/api/remote.rs index 945bb7332..48cd32fb8 100644 --- a/src/api/remote.rs +++ b/src/api/remote.rs @@ -662,6 +662,8 @@ impl Remote { ) -> GetResult { let store = self.store(); let root = request.hash; + // I am cloning the connection, but it's fine because the original connection or ConnectionRef stays alive + // for the duration of the operation. let start = crate::get::fsm::start(conn.clone(), request, Default::default()); let connected = start.next().await?; trace!("Getting header"); diff --git a/src/util/connection_pool.rs b/src/util/connection_pool.rs index 1ae0a918c..d4efecb19 100644 --- a/src/util/connection_pool.rs +++ b/src/util/connection_pool.rs @@ -19,7 +19,7 @@ use std::{ }; use iroh::{ - endpoint::{ConnectError, Connection}, + endpoint::ConnectError, Endpoint, NodeId, }; use n0_future::{ @@ -111,8 +111,6 @@ pub enum ConnectionPoolError { Shutdown, } -pub type PoolConnectResult = std::result::Result; - enum ActorMessage { RequestRef(RequestRef), ConnectionIdle { id: NodeId }, From 775adee562f2ba26acbb1a0cab0a0cd1d698b9ba Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Thu, 21 Aug 2025 12:06:47 +0200 Subject: [PATCH 03/16] Update docs! --- src/util/connection_pool.rs | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/util/connection_pool.rs b/src/util/connection_pool.rs index d4efecb19..191e7864e 100644 --- a/src/util/connection_pool.rs +++ b/src/util/connection_pool.rs @@ -4,10 +4,10 @@ //! ALPN and [`Options`]. Then the pool will manage connections for you. //! //! 
Access to connections is via the [`ConnectionPool::connect`] method, which -//! gives you access to a connection if possible. +//! gives you access to a connection via a [`ConnectionRef`] if possible. //! -//! It is important that you use the connection only in the future passed to -//! connect, and don't clone it out of the future. +//! It is important that you keep the [`ConnectionRef`] alive while you are using +//! the connection. use std::{ collections::{HashMap, VecDeque}, ops::Deref, @@ -18,10 +18,7 @@ use std::{ time::Duration, }; -use iroh::{ - endpoint::ConnectError, - Endpoint, NodeId, -}; +use iroh::{endpoint::ConnectError, Endpoint, NodeId}; use n0_future::{ future::{self}, FuturesUnordered, MaybeFuture, Stream, StreamExt, From 24c5895fcf81aa0bb7080015343ce2c925530eb5 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Thu, 21 Aug 2025 12:15:12 +0200 Subject: [PATCH 04/16] cargo deny --- Cargo.lock | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 19c59bac3..4068354f7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3617,9 +3617,9 @@ dependencies = [ [[package]] name = "slab" -version = "0.4.10" +version = "0.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04dc19736151f35336d325007ac991178d504a119863a2fcb3758cdb5e52c50d" +checksum = "7a2ae44ef20feb57a68b23d846850f861394c2e02dc425a50098ae8c90267589" [[package]] name = "smallvec" From 806fa453bd2286fe9b0b7d91bcc8a4a45a6ed8b5 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Fri, 22 Aug 2025 09:21:09 +0200 Subject: [PATCH 05/16] Move the ConnectionPool to util and restructure a bunch of stuff so we can make util public. The ConnPool will probably move somewhere else in the longer term. --- src/api/blobs/reader.rs | 2 +- src/api/downloader.rs | 2 +- src/api/remote.rs | 3 +- src/get/request.rs | 3 +- src/lib.rs | 2 +- src/protocol.rs | 77 ++++++++++++++++++++++++++++++++-- src/protocol/range_spec.rs | 4 +- src/store/fs.rs | 2 +- src/store/mem.rs | 6 +-- src/store/readonly_mem.rs | 2 +- src/util.rs | 83 +++---------------------------------- src/util/connection_pool.rs | 4 +- 12 files changed, 93 insertions(+), 97 deletions(-) diff --git a/src/api/blobs/reader.rs b/src/api/blobs/reader.rs index e15e374d4..9e337dae1 100644 --- a/src/api/blobs/reader.rs +++ b/src/api/blobs/reader.rs @@ -221,6 +221,7 @@ mod tests { use super::*; use crate::{ + protocol::ChunkRangesExt, store::{ fs::{ tests::{create_n0_bao, test_data, INTERESTING_SIZES}, @@ -228,7 +229,6 @@ mod tests { }, mem::MemStore, }, - util::ChunkRangesExt, }; async fn reader_smoke(blobs: &Blobs) -> TestResult<()> { diff --git a/src/api/downloader.rs b/src/api/downloader.rs index 0d5375978..a2abbd7ea 100644 --- a/src/api/downloader.rs +++ b/src/api/downloader.rs @@ -442,7 +442,7 @@ async fn execute_get( request: request.clone(), }) .await?; - let conn = pool.connect(provider); + let conn = pool.get_or_connect(provider); let local = remote.local_for_request(request.clone()).await?; if local.is_complete() { return Ok(()); diff --git a/src/api/remote.rs b/src/api/remote.rs index 48cd32fb8..623200900 100644 --- a/src/api/remote.rs +++ b/src/api/remote.rs @@ -1067,7 +1067,7 @@ mod tests { use crate::{ api::blobs::Blobs, - protocol::{ChunkRangesSeq, GetRequest}, + protocol::{ChunkRangesExt, ChunkRangesSeq, GetRequest}, store::{ fs::{ tests::{create_n0_bao, test_data, INTERESTING_SIZES}, @@ -1076,7 +1076,6 @@ mod tests { mem::MemStore, }, tests::{add_test_hash_seq, add_test_hash_seq_incomplete}, - 
util::ChunkRangesExt,
 };

 #[tokio::test]
diff --git a/src/get/request.rs b/src/get/request.rs
index 86ffcabb2..98563057e 100644
--- a/src/get/request.rs
+++ b/src/get/request.rs
@@ -27,8 +27,7 @@ use super::{fsm, GetError, GetResult, Stats};
 use crate::{
     get::error::{BadRequestSnafu, LocalFailureSnafu},
     hashseq::HashSeq,
-    protocol::{ChunkRangesSeq, GetRequest},
-    util::ChunkRangesExt,
+    protocol::{ChunkRangesExt, ChunkRangesSeq, GetRequest},
     Hash, HashAndFormat,
 };

diff --git a/src/lib.rs b/src/lib.rs
index ed4f78506..521ba4f7f 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -43,7 +43,7 @@ pub mod ticket;
 #[doc(hidden)]
 pub mod test;

-mod util;
+pub mod util;

 #[cfg(test)]
 mod tests;
diff --git a/src/protocol.rs b/src/protocol.rs
index 850431996..74e0f986d 100644
--- a/src/protocol.rs
+++ b/src/protocol.rs
@@ -373,13 +373,18 @@
 //! a large existing system that has demonstrated performance issues.
 //!
 //! If in doubt, just use multiple requests and multiple connections.
-use std::io;
+use std::{
+    io,
+    ops::{Bound, RangeBounds},
+};

+use bao_tree::{io::round_up_to_chunks, ChunkNum};
 use builder::GetRequestBuilder;
 use derive_more::From;
 use iroh::endpoint::VarInt;
 use irpc::util::AsyncReadVarintExt;
 use postcard::experimental::max_size::MaxSize;
+use range_collections::{range_set::RangeSetEntry, RangeSet2};
 use serde::{Deserialize, Serialize};
 mod range_spec;
 pub use bao_tree::ChunkRanges;
@@ -387,7 +392,6 @@ pub use range_spec::{ChunkRangesSeq, NonEmptyRequestRangeSpecIter, RangeSpec};
 use snafu::{GenerateImplicitData, Snafu};
 use tokio::io::AsyncReadExt;

-pub use crate::util::ChunkRangesExt;
 use crate::{api::blobs::Bitfield, provider::CountingReader, BlobFormat, Hash, HashAndFormat};

 /// Maximum message size is limited to 100MiB for now.
@@ -714,6 +718,73 @@ impl TryFrom<VarInt> for Closed {
     }
 }

+pub trait ChunkRangesExt {
+    fn last_chunk() -> Self;
+    fn chunk(offset: u64) -> Self;
+    fn bytes(ranges: impl RangeBounds<u64>) -> Self;
+    fn chunks(ranges: impl RangeBounds<u64>) -> Self;
+    fn offset(offset: u64) -> Self;
+}
+
+impl ChunkRangesExt for ChunkRanges {
+    fn last_chunk() -> Self {
+        ChunkRanges::from(ChunkNum(u64::MAX)..)
+    }
+
+    /// Create a chunk range that contains a single chunk.
+    fn chunk(offset: u64) -> Self {
+        ChunkRanges::from(ChunkNum(offset)..ChunkNum(offset + 1))
+    }
+
+    /// Create a range of chunks that contains the given byte ranges.
+    /// The byte ranges are rounded up to the nearest chunk size.
+    fn bytes(ranges: impl RangeBounds<u64>) -> Self {
+        round_up_to_chunks(&bounds_from_range(ranges, |v| v))
+    }
+
+    /// Create a range of chunks from u64 chunk bounds.
+    ///
+    /// This is equivalent but more convenient than using the ChunkNum newtype.
+    fn chunks(ranges: impl RangeBounds<u64>) -> Self {
+        bounds_from_range(ranges, ChunkNum)
+    }
+
+    /// Create a chunk range that contains a single byte offset.
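+    ///
+    /// For example, `ChunkRanges::offset(1024)` rounds up to whole chunks and
+    /// yields `ChunkNum(1)..ChunkNum(2)`, the single 1024-byte chunk that
+    /// holds byte 1024.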
+ fn offset(offset: u64) -> Self { + Self::bytes(offset..offset + 1) + } +} + +// todo: move to range_collections +pub(crate) fn bounds_from_range(range: R, f: F) -> RangeSet2 +where + R: RangeBounds, + T: RangeSetEntry, + F: Fn(u64) -> T, +{ + let from = match range.start_bound() { + Bound::Included(start) => Some(*start), + Bound::Excluded(start) => { + let Some(start) = start.checked_add(1) else { + return RangeSet2::empty(); + }; + Some(start) + } + Bound::Unbounded => None, + }; + let to = match range.end_bound() { + Bound::Included(end) => end.checked_add(1), + Bound::Excluded(end) => Some(*end), + Bound::Unbounded => None, + }; + match (from, to) { + (Some(from), Some(to)) => RangeSet2::from(f(from)..f(to)), + (Some(from), None) => RangeSet2::from(f(from)..), + (None, Some(to)) => RangeSet2::from(..f(to)), + (None, None) => RangeSet2::all(), + } +} + pub mod builder { use std::collections::BTreeMap; @@ -863,7 +934,7 @@ pub mod builder { use bao_tree::ChunkNum; use super::*; - use crate::{protocol::GetManyRequest, util::ChunkRangesExt}; + use crate::protocol::{ChunkRangesExt, GetManyRequest}; #[test] fn chunk_ranges_ext() { diff --git a/src/protocol/range_spec.rs b/src/protocol/range_spec.rs index 92cfe9382..546dbe702 100644 --- a/src/protocol/range_spec.rs +++ b/src/protocol/range_spec.rs @@ -12,7 +12,7 @@ use bao_tree::{ChunkNum, ChunkRangesRef}; use serde::{Deserialize, Serialize}; use smallvec::{smallvec, SmallVec}; -pub use crate::util::ChunkRangesExt; +use crate::protocol::ChunkRangesExt; static CHUNK_RANGES_EMPTY: OnceLock = OnceLock::new(); @@ -511,7 +511,7 @@ mod tests { use proptest::prelude::*; use super::*; - use crate::util::ChunkRangesExt; + use crate::protocol::ChunkRangesExt; fn ranges(value_range: Range) -> impl Strategy { prop::collection::vec((value_range.clone(), value_range), 0..16).prop_map(|v| { diff --git a/src/store/fs.rs b/src/store/fs.rs index 9e11e098f..e8a87ad60 100644 --- a/src/store/fs.rs +++ b/src/store/fs.rs @@ -111,6 +111,7 @@ use crate::{ }, ApiClient, }, + protocol::ChunkRangesExt, store::{ fs::{ bao_file::{ @@ -125,7 +126,6 @@ use crate::{ util::{ channel::oneshot, temp_tag::{TagDrop, TempTag, TempTagScope, TempTags}, - ChunkRangesExt, }, }; mod bao_file; diff --git a/src/store/mem.rs b/src/store/mem.rs index 6d022e0f8..8a2a227b7 100644 --- a/src/store/mem.rs +++ b/src/store/mem.rs @@ -56,14 +56,12 @@ use crate::{ tags::TagInfo, ApiClient, }, + protocol::ChunkRangesExt, store::{ util::{SizeInfo, SparseMemFile, Tag}, HashAndFormat, IROH_BLOCK_SIZE, }, - util::{ - temp_tag::{TagDrop, TempTagScope, TempTags}, - ChunkRangesExt, - }, + util::temp_tag::{TagDrop, TempTagScope, TempTags}, BlobFormat, Hash, }; diff --git a/src/store/readonly_mem.rs b/src/store/readonly_mem.rs index 42274b2e2..0d9b19367 100644 --- a/src/store/readonly_mem.rs +++ b/src/store/readonly_mem.rs @@ -41,8 +41,8 @@ use crate::{ }, ApiClient, TempTag, }, + protocol::ChunkRangesExt, store::{mem::CompleteStorage, IROH_BLOCK_SIZE}, - util::ChunkRangesExt, Hash, }; diff --git a/src/util.rs b/src/util.rs index f3e493934..3fdaacbca 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,13 +1,9 @@ -use std::ops::{Bound, RangeBounds}; - -use bao_tree::{io::round_up_to_chunks, ChunkNum, ChunkRanges}; -use range_collections::{range_set::RangeSetEntry, RangeSet2}; - -pub mod channel; -pub(crate) mod connection_pool; +//! 
Utilities +pub(crate) mod channel; +pub mod connection_pool; pub(crate) mod temp_tag; -pub mod serde { +pub(crate) mod serde { // Module that handles io::Error serialization/deserialization pub mod io_error_serde { use std::{fmt, io}; @@ -218,74 +214,7 @@ pub mod serde { } } -pub trait ChunkRangesExt { - fn last_chunk() -> Self; - fn chunk(offset: u64) -> Self; - fn bytes(ranges: impl RangeBounds) -> Self; - fn chunks(ranges: impl RangeBounds) -> Self; - fn offset(offset: u64) -> Self; -} - -impl ChunkRangesExt for ChunkRanges { - fn last_chunk() -> Self { - ChunkRanges::from(ChunkNum(u64::MAX)..) - } - - /// Create a chunk range that contains a single chunk. - fn chunk(offset: u64) -> Self { - ChunkRanges::from(ChunkNum(offset)..ChunkNum(offset + 1)) - } - - /// Create a range of chunks that contains the given byte ranges. - /// The byte ranges are rounded up to the nearest chunk size. - fn bytes(ranges: impl RangeBounds) -> Self { - round_up_to_chunks(&bounds_from_range(ranges, |v| v)) - } - - /// Create a range of chunks from u64 chunk bounds. - /// - /// This is equivalent but more convenient than using the ChunkNum newtype. - fn chunks(ranges: impl RangeBounds) -> Self { - bounds_from_range(ranges, ChunkNum) - } - - /// Create a chunk range that contains a single byte offset. - fn offset(offset: u64) -> Self { - Self::bytes(offset..offset + 1) - } -} - -// todo: move to range_collections -pub(crate) fn bounds_from_range(range: R, f: F) -> RangeSet2 -where - R: RangeBounds, - T: RangeSetEntry, - F: Fn(u64) -> T, -{ - let from = match range.start_bound() { - Bound::Included(start) => Some(*start), - Bound::Excluded(start) => { - let Some(start) = start.checked_add(1) else { - return RangeSet2::empty(); - }; - Some(start) - } - Bound::Unbounded => None, - }; - let to = match range.end_bound() { - Bound::Included(end) => end.checked_add(1), - Bound::Excluded(end) => Some(*end), - Bound::Unbounded => None, - }; - match (from, to) { - (Some(from), Some(to)) => RangeSet2::from(f(from)..f(to)), - (Some(from), None) => RangeSet2::from(f(from)..), - (None, Some(to)) => RangeSet2::from(..f(to)), - (None, None) => RangeSet2::all(), - } -} - -pub mod outboard_with_progress { +pub(crate) mod outboard_with_progress { use std::io::{self, BufReader, Read}; use bao_tree::{ @@ -433,7 +362,7 @@ pub mod outboard_with_progress { } } -pub mod sink { +pub(crate) mod sink { use std::{future::Future, io}; use irpc::RpcMessage; diff --git a/src/util/connection_pool.rs b/src/util/connection_pool.rs index 191e7864e..7b283866d 100644 --- a/src/util/connection_pool.rs +++ b/src/util/connection_pool.rs @@ -3,7 +3,7 @@ //! Entry point is [`ConnectionPool`]. You create a connection pool for a specific //! ALPN and [`Options`]. Then the pool will manage connections for you. //! -//! Access to connections is via the [`ConnectionPool::connect`] method, which +//! Access to connections is via the [`ConnectionPool::get_or_connect`] method, which //! gives you access to a connection via a [`ConnectionRef`] if possible. //! //! It is important that you keep the [`ConnectionRef`] alive while you are using @@ -360,7 +360,7 @@ impl ConnectionPool { /// /// This is guaranteed to return after approximately [Options::connect_timeout] /// with either an error or a connection. 
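+    ///
+    /// Sketch of caller-side error handling (hypothetical retry policy):
+    ///
+    /// ```ignore
+    /// match pool.get_or_connect(node_id).await {
+    ///     Ok(conn) => { /* use conn; keep the ConnectionRef alive while doing so */ }
+    ///     Err(PoolConnectError::TooManyConnections) => { /* back off and retry later */ }
+    ///     Err(other) => return Err(other.into()),
+    /// }
+    /// ```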
- pub async fn connect( + pub async fn get_or_connect( &self, id: NodeId, ) -> std::result::Result { From 0edac5cf05bae2caacd560e6ed7008e80152c0ff Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Tue, 26 Aug 2025 11:19:42 +0200 Subject: [PATCH 06/16] Add config object --- src/util/connection_pool.rs | 35 +++++++++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/src/util/connection_pool.rs b/src/util/connection_pool.rs index 7b283866d..07f3d952f 100644 --- a/src/util/connection_pool.rs +++ b/src/util/connection_pool.rs @@ -29,7 +29,7 @@ use tokio::sync::{ oneshot, Notify, }; use tokio_util::time::FutureExt as TimeFutureExt; -use tracing::{debug, error, trace}; +use tracing::{debug, error, info, trace}; /// Configuration options for the connection pool #[derive(Debug, Clone, Copy)] @@ -142,17 +142,26 @@ impl Context { .await .map_err(|_| PoolConnectError::Timeout) .and_then(|r| r.map_err(PoolConnectError::from)); - if let Err(e) = &state { - debug!(%node_id, "Failed to connect {e:?}, requesting shutdown"); - if context.owner.close(node_id).await.is_err() { - return; + let conn_close = match &state { + Ok(conn) => { + let conn = conn.clone(); + MaybeFuture::Some(async move { conn.closed().await }) } - } + Err(e) => { + debug!(%node_id, "Failed to connect {e:?}, requesting shutdown"); + tokio::time::sleep(Duration::from_secs(1)).await; + if context.owner.close(node_id).await.is_err() { + return; + } + MaybeFuture::None + } + }; + let counter = ConnectionCounter::new(); let idle_timer = MaybeFuture::default(); let idle_stream = counter.clone().idle_stream(); - tokio::pin!(idle_timer, idle_stream); + tokio::pin!(idle_timer, idle_stream, conn_close); loop { tokio::select! { @@ -166,6 +175,7 @@ impl Context { match &state { Ok(state) => { let res = ConnectionRef::new(state.clone(), counter.get_one()); + info!(%node_id, "Handing out ConnectionRef {}", counter.current()); // clear the idle timer idle_timer.as_mut().set_none(); @@ -177,12 +187,17 @@ impl Context { } } None => { - // Channel closed - finish remaining tasks and exit + // Channel closed - exit break; } } } + _ = &mut conn_close => { + // connection was closed by somebody, notify owner that we should be removed + context.owner.close(node_id).await.ok(); + } + _ = idle_stream.next() => { if !counter.is_idle() { continue; @@ -417,6 +432,10 @@ impl ConnectionCounter { } } + fn current(&self) -> usize { + self.inner.count.load(Ordering::SeqCst) + } + /// Increase the connection count and return a guard for the new connection fn get_one(&self) -> OneConnection { self.inner.count.fetch_add(1, Ordering::SeqCst); From 960bdfd7baca684a05876d07db290ce79bfb9132 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Wed, 27 Aug 2025 11:24:02 +0200 Subject: [PATCH 07/16] Add back tests from iroh-experiments/iroh-connection-pool --- src/util/connection_pool.rs | 230 ++++++++++++++++++++++++++++++++++++ 1 file changed, 230 insertions(+) diff --git a/src/util/connection_pool.rs b/src/util/connection_pool.rs index 07f3d952f..42352576a 100644 --- a/src/util/connection_pool.rs +++ b/src/util/connection_pool.rs @@ -477,3 +477,233 @@ impl Drop for OneConnection { } } } + +#[cfg(test)] +mod tests { + use std::{collections::BTreeMap, time::Duration}; + + use iroh::{ + discovery::static_provider::StaticProvider, + endpoint::Connection, + protocol::{AcceptError, ProtocolHandler, Router}, + NodeAddr, NodeId, SecretKey, Watcher, + }; + use n0_future::{stream, BufferedStreamExt, StreamExt}; + use n0_snafu::ResultExt; + use 
testresult::TestResult; + use tracing::trace; + + use super::{ConnectionPool, Options, PoolConnectError}; + + const ECHO_ALPN: &[u8] = b"echo"; + + #[derive(Debug, Clone)] + struct Echo; + + impl ProtocolHandler for Echo { + async fn accept(&self, connection: Connection) -> Result<(), AcceptError> { + let conn_id = connection.stable_id(); + let id = connection.remote_node_id().map_err(AcceptError::from_err)?; + trace!(%id, %conn_id, "Accepting echo connection"); + loop { + match connection.accept_bi().await { + Ok((mut send, mut recv)) => { + trace!(%id, %conn_id, "Accepted echo request"); + tokio::io::copy(&mut recv, &mut send).await?; + send.finish().map_err(AcceptError::from_err)?; + } + Err(e) => { + trace!(%id, %conn_id, "Failed to accept echo request {e}"); + break; + } + } + } + Ok(()) + } + } + + async fn echo_client(conn: &Connection, text: &[u8]) -> n0_snafu::Result> { + let conn_id = conn.stable_id(); + let id = conn.remote_node_id().e()?; + trace!(%id, %conn_id, "Sending echo request"); + let (mut send, mut recv) = conn.open_bi().await.e()?; + send.write_all(text).await.e()?; + send.finish().e()?; + let response = recv.read_to_end(1000).await.e()?; + trace!(%id, %conn_id, "Received echo response"); + Ok(response) + } + + async fn echo_server() -> TestResult<(NodeAddr, Router)> { + let endpoint = iroh::Endpoint::builder() + .alpns(vec![ECHO_ALPN.to_vec()]) + .bind() + .await?; + let addr = endpoint.node_addr().initialized().await; + let router = iroh::protocol::Router::builder(endpoint) + .accept(ECHO_ALPN, Echo) + .spawn(); + + Ok((addr, router)) + } + + async fn echo_servers(n: usize) -> TestResult> { + stream::iter(0..n) + .map(|_| echo_server()) + .buffered_unordered(16) + .collect::>() + .await + .into_iter() + .collect() + } + + fn test_options() -> Options { + Options { + idle_timeout: Duration::from_millis(100), + connect_timeout: Duration::from_secs(2), + max_connections: 32, + on_connect: None, + } + } + + struct EchoClient { + pool: ConnectionPool, + } + + impl EchoClient { + async fn echo( + &self, + id: NodeId, + text: Vec, + ) -> Result), n0_snafu::Error>, PoolConnectError> { + let conn = self.pool.get_or_connect(id).await?; + let id = conn.stable_id(); + match echo_client(&conn, &text).await { + Ok(res) => Ok(Ok((id, res))), + Err(e) => Ok(Err(e)), + } + } + } + + #[tokio::test] + async fn connection_pool_errors() -> TestResult<()> { + let filter = tracing_subscriber::EnvFilter::from_default_env(); + tracing_subscriber::fmt() + .with_env_filter(filter) + .try_init() + .ok(); + // set up static discovery for all addrs + let discovery = StaticProvider::new(); + let endpoint = iroh::Endpoint::builder() + .discovery(discovery.clone()) + .bind() + .await?; + let pool = ConnectionPool::new(endpoint, ECHO_ALPN, test_options()); + let client = EchoClient { pool }; + { + let non_existing = SecretKey::from_bytes(&[0; 32]).public(); + let res = client.echo(non_existing, b"Hello, world!".to_vec()).await; + // trying to connect to a non-existing id will fail with ConnectError + // because we don't have any information about the node + assert!(matches!(res, Err(PoolConnectError::ConnectError { .. 
}))); + } + { + let non_listening = SecretKey::from_bytes(&[0; 32]).public(); + // make up fake node info + discovery.add_node_info(NodeAddr { + node_id: non_listening, + relay_url: None, + direct_addresses: vec!["127.0.0.1:12121".parse().unwrap()] + .into_iter() + .collect(), + }); + // trying to connect to an id for which we have info, but the other + // end is not listening, will lead to a timeout. + let res = client.echo(non_listening, b"Hello, world!".to_vec()).await; + assert!(matches!(res, Err(PoolConnectError::Timeout))); + } + Ok(()) + } + + #[tokio::test] + async fn connection_pool_smoke() -> TestResult<()> { + let filter = tracing_subscriber::EnvFilter::from_default_env(); + tracing_subscriber::fmt() + .with_env_filter(filter) + .try_init() + .ok(); + let n = 32; + let nodes = echo_servers(n).await?; + let ids = nodes + .iter() + .map(|(addr, _)| addr.node_id) + .collect::>(); + // set up static discovery for all addrs + let discovery = StaticProvider::from_node_info(nodes.iter().map(|(addr, _)| addr.clone())); + // build a client endpoint that can resolve all the node ids + let endpoint = iroh::Endpoint::builder() + .discovery(discovery.clone()) + .bind() + .await?; + let pool = ConnectionPool::new(endpoint.clone(), ECHO_ALPN, test_options()); + let client = EchoClient { pool }; + let mut connection_ids = BTreeMap::new(); + let msg = b"Hello, world!".to_vec(); + for id in &ids { + let (cid1, res) = client.echo(*id, msg.clone()).await??; + assert_eq!(res, msg); + let (cid2, res) = client.echo(*id, msg.clone()).await??; + assert_eq!(res, msg); + assert_eq!(cid1, cid2); + connection_ids.insert(id, cid1); + } + tokio::time::sleep(Duration::from_millis(1000)).await; + for id in &ids { + let cid1 = *connection_ids.get(id).expect("Connection ID not found"); + let (cid2, res) = client.echo(*id, msg.clone()).await??; + assert_eq!(res, msg); + assert_ne!(cid1, cid2); + } + Ok(()) + } + + /// Tests that idle connections are being reclaimed to make room if we hit the + /// maximum connection limit. + #[tokio::test] + async fn connection_pool_idle() -> TestResult<()> { + let filter = tracing_subscriber::EnvFilter::from_default_env(); + tracing_subscriber::fmt() + .with_env_filter(filter) + .try_init() + .ok(); + let n = 32; + let nodes = echo_servers(n).await?; + let ids = nodes + .iter() + .map(|(addr, _)| addr.node_id) + .collect::>(); + // set up static discovery for all addrs + let discovery = StaticProvider::from_node_info(nodes.iter().map(|(addr, _)| addr.clone())); + // build a client endpoint that can resolve all the node ids + let endpoint = iroh::Endpoint::builder() + .discovery(discovery.clone()) + .bind() + .await?; + let pool = ConnectionPool::new( + endpoint.clone(), + ECHO_ALPN, + Options { + idle_timeout: Duration::from_secs(100), + max_connections: 8, + ..test_options() + }, + ); + let client = EchoClient { pool }; + let msg = b"Hello, world!".to_vec(); + for id in &ids { + let (_, res) = client.echo(*id, msg.clone()).await??; + assert_eq!(res, msg); + } + Ok(()) + } +} From a32e159f5eb9d806e1a68d3c81f0c1eabca31047 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Wed, 27 Aug 2025 11:24:16 +0200 Subject: [PATCH 08/16] Implement on_connect cb. --- src/util/connection_pool.rs | 49 +++++++++++++++++++++++++++++++++---- 1 file changed, 44 insertions(+), 5 deletions(-) diff --git a/src/util/connection_pool.rs b/src/util/connection_pool.rs index 42352576a..ff6c12c8f 100644 --- a/src/util/connection_pool.rs +++ b/src/util/connection_pool.rs @@ -10,6 +10,7 @@ //! the connection. 
use std::{ collections::{HashMap, VecDeque}, + io, ops::Deref, sync::{ atomic::{AtomicUsize, Ordering}, @@ -18,7 +19,10 @@ use std::{ time::Duration, }; -use iroh::{endpoint::ConnectError, Endpoint, NodeId}; +use iroh::{ + endpoint::{ConnectError, Connection}, + Endpoint, NodeId, +}; use n0_future::{ future::{self}, FuturesUnordered, MaybeFuture, Stream, StreamExt, @@ -31,12 +35,23 @@ use tokio::sync::{ use tokio_util::time::FutureExt as TimeFutureExt; use tracing::{debug, error, info, trace}; +pub type OnConnected = + Arc n0_future::future::Boxed> + Send + Sync>; + /// Configuration options for the connection pool -#[derive(Debug, Clone, Copy)] +#[derive(derive_more::Debug, Clone)] pub struct Options { + /// How long to keep idle connections around. pub idle_timeout: Duration, + /// Timeout for connect. This includes the time spent in on_connect, if set. pub connect_timeout: Duration, + /// Maximum number of connections to hand out. pub max_connections: usize, + /// An optional callback that can be used to wait for the connection to enter some state. + /// An example usage could be to wait for the connection to become direct before handing + /// it out to the user. + #[debug(skip)] + pub on_connect: Option, } impl Default for Options { @@ -45,6 +60,7 @@ impl Default for Options { idle_timeout: Duration::from_secs(5), connect_timeout: Duration::from_secs(1), max_connections: 1024, + on_connect: None, } } } @@ -88,6 +104,8 @@ pub enum PoolConnectError { TooManyConnections, /// Error during connect ConnectError { source: Arc }, + /// Error during on_connect callback + OnConnectError { source: Arc }, } impl From for PoolConnectError { @@ -98,6 +116,14 @@ impl From for PoolConnectError { } } +impl From for PoolConnectError { + fn from(e: io::Error) -> Self { + PoolConnectError::OnConnectError { + source: Arc::new(e), + } + } +} + /// Error when calling a fn on the [`ConnectionPool`]. /// /// The only thing that can go wrong is that the connection pool is shut down. @@ -134,10 +160,23 @@ impl Context { ) { let context = self; + let context2 = context.clone(); + let conn_fut = async move { + let conn = context2 + .endpoint + .connect(node_id, &context2.alpn) + .await + .map_err(PoolConnectError::from)?; + if let Some(on_connect) = &context2.options.on_connect { + on_connect(&context2.endpoint, &conn) + .await + .map_err(PoolConnectError::from)?; + } + Result::::Ok(conn) + }; + // Connect to the node - let state = context - .endpoint - .connect(node_id, &context.alpn) + let state = conn_fut .timeout(context.options.connect_timeout) .await .map_err(|_| PoolConnectError::Timeout) From 8ba7d91cb8471429d0aceba5ac7ede444c865ceb Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Wed, 27 Aug 2025 12:29:51 +0200 Subject: [PATCH 09/16] Add test for connection pool on_connected timeout. Also shut down routers. --- src/util/connection_pool.rs | 143 +++++++++++++++++++++++++----------- 1 file changed, 101 insertions(+), 42 deletions(-) diff --git a/src/util/connection_pool.rs b/src/util/connection_pool.rs index ff6c12c8f..91fffdd19 100644 --- a/src/util/connection_pool.rs +++ b/src/util/connection_pool.rs @@ -51,7 +51,7 @@ pub struct Options { /// An example usage could be to wait for the connection to become direct before handing /// it out to the user. 
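+    ///
+    /// Sketch of a permissive callback that only logs (a real one might wait
+    /// for a direct path or inspect the connection first):
+    ///
+    /// ```ignore
+    /// let cb: OnConnected = Arc::new(|_ep, conn| {
+    ///     let id = conn.stable_id();
+    ///     Box::pin(async move {
+    ///         tracing::debug!(%id, "connection ready");
+    ///         Ok(())
+    ///     })
+    /// });
+    /// ```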
#[debug(skip)] - pub on_connect: Option, + pub on_connected: Option, } impl Default for Options { @@ -60,7 +60,7 @@ impl Default for Options { idle_timeout: Duration::from_secs(5), connect_timeout: Duration::from_secs(1), max_connections: 1024, - on_connect: None, + on_connected: None, } } } @@ -167,7 +167,7 @@ impl Context { .connect(node_id, &context2.alpn) .await .map_err(PoolConnectError::from)?; - if let Some(on_connect) = &context2.options.on_connect { + if let Some(on_connect) = &context2.options.on_connected { on_connect(&context2.endpoint, &conn) .await .map_err(PoolConnectError::from)?; @@ -519,7 +519,7 @@ impl Drop for OneConnection { #[cfg(test)] mod tests { - use std::{collections::BTreeMap, time::Duration}; + use std::{collections::BTreeMap, sync::Arc, time::Duration}; use iroh::{ discovery::static_provider::StaticProvider, @@ -527,12 +527,14 @@ mod tests { protocol::{AcceptError, ProtocolHandler, Router}, NodeAddr, NodeId, SecretKey, Watcher, }; - use n0_future::{stream, BufferedStreamExt, StreamExt}; + use n0_future::{io, stream, BufferedStreamExt, StreamExt}; use n0_snafu::ResultExt; use testresult::TestResult; use tracing::trace; + use tracing_test::traced_test; use super::{ConnectionPool, Options, PoolConnectError}; + use crate::util::connection_pool::OnConnected; const ECHO_ALPN: &[u8] = b"echo"; @@ -586,14 +588,25 @@ mod tests { Ok((addr, router)) } - async fn echo_servers(n: usize) -> TestResult> { - stream::iter(0..n) + async fn echo_servers(n: usize) -> TestResult<(Vec, Vec, StaticProvider)> { + let res = stream::iter(0..n) .map(|_| echo_server()) .buffered_unordered(16) .collect::>() - .await - .into_iter() - .collect() + .await; + let res: Vec<(NodeAddr, Router)> = res.into_iter().collect::>>()?; + let (addrs, routers): (Vec<_>, Vec<_>) = res.into_iter().unzip(); + let ids = addrs.iter().map(|a| a.node_id).collect::>(); + let discovery = StaticProvider::from_node_info(addrs); + Ok((ids, routers, discovery)) + } + + async fn shutdown_routers(routers: Vec) { + stream::iter(routers) + .for_each_concurrent(16, |router| async move { + let _ = router.shutdown().await; + }) + .await; } fn test_options() -> Options { @@ -601,7 +614,7 @@ mod tests { idle_timeout: Duration::from_millis(100), connect_timeout: Duration::from_secs(2), max_connections: 32, - on_connect: None, + on_connected: None, } } @@ -625,12 +638,8 @@ mod tests { } #[tokio::test] + #[traced_test] async fn connection_pool_errors() -> TestResult<()> { - let filter = tracing_subscriber::EnvFilter::from_default_env(); - tracing_subscriber::fmt() - .with_env_filter(filter) - .try_init() - .ok(); // set up static discovery for all addrs let discovery = StaticProvider::new(); let endpoint = iroh::Endpoint::builder() @@ -665,20 +674,10 @@ mod tests { } #[tokio::test] + #[traced_test] async fn connection_pool_smoke() -> TestResult<()> { - let filter = tracing_subscriber::EnvFilter::from_default_env(); - tracing_subscriber::fmt() - .with_env_filter(filter) - .try_init() - .ok(); let n = 32; - let nodes = echo_servers(n).await?; - let ids = nodes - .iter() - .map(|(addr, _)| addr.node_id) - .collect::>(); - // set up static discovery for all addrs - let discovery = StaticProvider::from_node_info(nodes.iter().map(|(addr, _)| addr.clone())); + let (ids, routers, discovery) = echo_servers(n).await?; // build a client endpoint that can resolve all the node ids let endpoint = iroh::Endpoint::builder() .discovery(discovery.clone()) @@ -687,7 +686,7 @@ mod tests { let pool = ConnectionPool::new(endpoint.clone(), ECHO_ALPN, 
test_options()); let client = EchoClient { pool }; let mut connection_ids = BTreeMap::new(); - let msg = b"Hello, world!".to_vec(); + let msg = b"Hello, pool!".to_vec(); for id in &ids { let (cid1, res) = client.echo(*id, msg.clone()).await??; assert_eq!(res, msg); @@ -703,26 +702,17 @@ mod tests { assert_eq!(res, msg); assert_ne!(cid1, cid2); } + shutdown_routers(routers).await; Ok(()) } /// Tests that idle connections are being reclaimed to make room if we hit the /// maximum connection limit. #[tokio::test] + #[traced_test] async fn connection_pool_idle() -> TestResult<()> { - let filter = tracing_subscriber::EnvFilter::from_default_env(); - tracing_subscriber::fmt() - .with_env_filter(filter) - .try_init() - .ok(); let n = 32; - let nodes = echo_servers(n).await?; - let ids = nodes - .iter() - .map(|(addr, _)| addr.node_id) - .collect::>(); - // set up static discovery for all addrs - let discovery = StaticProvider::from_node_info(nodes.iter().map(|(addr, _)| addr.clone())); + let (ids, routers, discovery) = echo_servers(n).await?; // build a client endpoint that can resolve all the node ids let endpoint = iroh::Endpoint::builder() .discovery(discovery.clone()) @@ -738,11 +728,80 @@ mod tests { }, ); let client = EchoClient { pool }; - let msg = b"Hello, world!".to_vec(); + let msg = b"Hello, pool!".to_vec(); for id in &ids { let (_, res) = client.echo(*id, msg.clone()).await??; assert_eq!(res, msg); } + shutdown_routers(routers).await; + Ok(()) + } + + /// Uses an on_connected callback that just errors out every time. + /// + /// This is a basic smoke test that on_connected gets called at all. + #[tokio::test] + #[traced_test] + async fn on_connected_error() -> TestResult<()> { + let n = 1; + let (ids, routers, discovery) = echo_servers(n).await?; + let endpoint = iroh::Endpoint::builder() + .discovery(discovery) + .bind() + .await?; + let on_connected: OnConnected = + Arc::new(|_, _| Box::pin(async { Err(io::Error::other("on_connect failed")) })); + let pool = ConnectionPool::new( + endpoint, + ECHO_ALPN, + Options { + on_connected: Some(on_connected), + ..test_options() + }, + ); + let client = EchoClient { pool }; + let msg = b"Hello, pool!".to_vec(); + for id in &ids { + let res = client.echo(*id, msg.clone()).await; + assert!(matches!(res, Err(PoolConnectError::OnConnectError { .. }))); + } + shutdown_routers(routers).await; + Ok(()) + } + + /// Uses an on_connected callback that delays for a long time. + /// + /// This checks that the pool timeout includes on_connected delay. + #[tokio::test] + #[traced_test] + async fn on_connected_timeout() -> TestResult<()> { + let n = 1; + let (ids, routers, discovery) = echo_servers(n).await?; + let endpoint = iroh::Endpoint::builder() + .discovery(discovery) + .bind() + .await?; + let on_connected: OnConnected = Arc::new(|_, _| { + Box::pin(async { + tokio::time::sleep(Duration::from_secs(2)).await; + Ok(()) + }) + }); + let pool = ConnectionPool::new( + endpoint, + ECHO_ALPN, + Options { + on_connected: Some(on_connected), + ..test_options() + }, + ); + let client = EchoClient { pool }; + let msg = b"Hello, pool!".to_vec(); + for id in &ids { + let res = client.echo(*id, msg.clone()).await; + assert!(matches!(res, Err(PoolConnectError::Timeout { .. }))); + } + shutdown_routers(routers).await; Ok(()) } } From 4ca9f2b6f69417b6f2e646d17ffeddda7b8888b6 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Wed, 27 Aug 2025 12:51:19 +0200 Subject: [PATCH 10/16] Add test for connection close. 
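The pool notices remote closes because the per-connection actor (since the
config-object commit) selects on the connection's close future:

    let conn = conn.clone();
    MaybeFuture::Some(async move { conn.closed().await })

and, once it fires, asks the main actor to drop the connection, so the next
get_or_connect() hands out a fresh connection with a new stable_id().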
--- src/util/connection_pool.rs | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/src/util/connection_pool.rs b/src/util/connection_pool.rs index 91fffdd19..76d0f408e 100644 --- a/src/util/connection_pool.rs +++ b/src/util/connection_pool.rs @@ -804,4 +804,30 @@ mod tests { shutdown_routers(routers).await; Ok(()) } + + /// Check that when a connection is closed, the pool will give you a new + /// connection next time you want one. + /// + /// This test fails if the connection watch is disabled. + #[tokio::test] + #[traced_test] + async fn watch_close() -> TestResult<()> { + let n = 1; + let (ids, routers, discovery) = echo_servers(n).await?; + let endpoint = iroh::Endpoint::builder() + .discovery(discovery) + .bind() + .await?; + + let pool = ConnectionPool::new(endpoint, ECHO_ALPN, test_options()); + let conn = pool.get_or_connect(ids[0]).await?; + let cid1 = conn.stable_id(); + conn.close(0u32.into(), b"test"); + tokio::time::sleep(Duration::from_millis(500)).await; + let conn = pool.get_or_connect(ids[0]).await?; + let cid2 = conn.stable_id(); + assert_ne!(cid1, cid2); + shutdown_routers(routers).await; + Ok(()) + } } From 3dc1c8b7a88dd5918080059bb5c296e1cea67a2f Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Wed, 27 Aug 2025 12:57:28 +0200 Subject: [PATCH 11/16] clippy --- src/util/connection_pool.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/util/connection_pool.rs b/src/util/connection_pool.rs index 76d0f408e..042bc58c6 100644 --- a/src/util/connection_pool.rs +++ b/src/util/connection_pool.rs @@ -180,7 +180,7 @@ impl Context { .timeout(context.options.connect_timeout) .await .map_err(|_| PoolConnectError::Timeout) - .and_then(|r| r.map_err(PoolConnectError::from)); + .and_then(|r| r); let conn_close = match &state { Ok(conn) => { let conn = conn.clone(); @@ -799,7 +799,7 @@ mod tests { let msg = b"Hello, pool!".to_vec(); for id in &ids { let res = client.echo(*id, msg.clone()).await; - assert!(matches!(res, Err(PoolConnectError::Timeout { .. }))); + assert!(matches!(res, Err(PoolConnectError::Timeout))); } shutdown_routers(routers).await; Ok(()) From 03e729b3b9d10c2a0f88f37afadb27b2d2d9b2a7 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Wed, 27 Aug 2025 15:25:03 +0200 Subject: [PATCH 12/16] Wait for relay --- src/util/connection_pool.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/util/connection_pool.rs b/src/util/connection_pool.rs index 042bc58c6..ab0488cf0 100644 --- a/src/util/connection_pool.rs +++ b/src/util/connection_pool.rs @@ -580,6 +580,7 @@ mod tests { .alpns(vec![ECHO_ALPN.to_vec()]) .bind() .await?; + endpoint.home_relay().initialized().await; let addr = endpoint.node_addr().initialized().await; let router = iroh::protocol::Router::builder(endpoint) .accept(ECHO_ALPN, Echo) From 6a46a0b5a91d6dc86255b781bad57856e4aba360 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Wed, 27 Aug 2025 15:31:42 +0200 Subject: [PATCH 13/16] Increase connect timeout. 
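
The 2s budget was too tight now that echo_server() waits for the home
relay before returning, and the single connect_timeout covers both the
QUIC dial and the on_connected callback. Paraphrasing the pool's connect
path from the clippy patch above (same names as in this series):

    let state = conn_fut // endpoint.connect(..) plus options.on_connected(..)
        .timeout(context.options.connect_timeout) // now 5s instead of 2s
        .await
        .map_err(|_| PoolConnectError::Timeout)
        .and_then(|r| r);

Because the callback runs inside the budget, the on_connected_timeout
test bumps its sleep to 20s so it still overshoots the larger limit.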
--- src/util/connection_pool.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/util/connection_pool.rs b/src/util/connection_pool.rs index ab0488cf0..5c0d01823 100644 --- a/src/util/connection_pool.rs +++ b/src/util/connection_pool.rs @@ -613,7 +613,7 @@ mod tests { fn test_options() -> Options { Options { idle_timeout: Duration::from_millis(100), - connect_timeout: Duration::from_secs(2), + connect_timeout: Duration::from_secs(5), max_connections: 32, on_connected: None, } @@ -784,7 +784,7 @@ mod tests { .await?; let on_connected: OnConnected = Arc::new(|_, _| { Box::pin(async { - tokio::time::sleep(Duration::from_secs(2)).await; + tokio::time::sleep(Duration::from_secs(20)).await; Ok(()) }) }); From 93053c9868b4055899e7b5fa7c6ae41a03977e6d Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Wed, 27 Aug 2025 15:40:23 +0200 Subject: [PATCH 14/16] disable traced_test it seems to cause weird problems with cross test. --- src/util/connection_pool.rs | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/util/connection_pool.rs b/src/util/connection_pool.rs index 5c0d01823..bb8e1934c 100644 --- a/src/util/connection_pool.rs +++ b/src/util/connection_pool.rs @@ -531,7 +531,6 @@ mod tests { use n0_snafu::ResultExt; use testresult::TestResult; use tracing::trace; - use tracing_test::traced_test; use super::{ConnectionPool, Options, PoolConnectError}; use crate::util::connection_pool::OnConnected; @@ -639,7 +638,7 @@ mod tests { } #[tokio::test] - #[traced_test] + // #[traced_test] async fn connection_pool_errors() -> TestResult<()> { // set up static discovery for all addrs let discovery = StaticProvider::new(); @@ -675,7 +674,7 @@ mod tests { } #[tokio::test] - #[traced_test] + // #[traced_test] async fn connection_pool_smoke() -> TestResult<()> { let n = 32; let (ids, routers, discovery) = echo_servers(n).await?; @@ -710,7 +709,7 @@ mod tests { /// Tests that idle connections are being reclaimed to make room if we hit the /// maximum connection limit. #[tokio::test] - #[traced_test] + // #[traced_test] async fn connection_pool_idle() -> TestResult<()> { let n = 32; let (ids, routers, discovery) = echo_servers(n).await?; @@ -742,7 +741,7 @@ mod tests { /// /// This is a basic smoke test that on_connected gets called at all. #[tokio::test] - #[traced_test] + // #[traced_test] async fn on_connected_error() -> TestResult<()> { let n = 1; let (ids, routers, discovery) = echo_servers(n).await?; @@ -774,7 +773,7 @@ mod tests { /// /// This checks that the pool timeout includes on_connected delay. #[tokio::test] - #[traced_test] + // #[traced_test] async fn on_connected_timeout() -> TestResult<()> { let n = 1; let (ids, routers, discovery) = echo_servers(n).await?; @@ -811,7 +810,7 @@ mod tests { /// /// This test fails if the connection watch is disabled. 
     #[tokio::test]
-    #[traced_test]
+    // #[traced_test]
     async fn watch_close() -> TestResult<()> {
         let n = 1;
         let (ids, routers, discovery) = echo_servers(n).await?;

From 030e8d914cd33c77e0fa05378d72642645b1f624 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=BCdiger=20Klaehn?=
Date: Thu, 28 Aug 2025 09:49:15 +0200
Subject: [PATCH 15/16] Update src/util/connection_pool.rs
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Philipp Krüger
---
 src/util/connection_pool.rs | 22 ++++++++++++----------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/src/util/connection_pool.rs b/src/util/connection_pool.rs
index bb8e1934c..f7187a5e9 100644
--- a/src/util/connection_pool.rs
+++ b/src/util/connection_pool.rs
@@ -160,19 +160,21 @@ impl Context {
     ) {
         let context = self;
-        let context2 = context.clone();
-        let conn_fut = async move {
-            let conn = context2
-                .endpoint
-                .connect(node_id, &context2.alpn)
-                .await
-                .map_err(PoolConnectError::from)?;
-            if let Some(on_connect) = &context2.options.on_connected {
-                on_connect(&context2.endpoint, &conn)
-                    .await
-                    .map_err(PoolConnectError::from)?;
+        let conn_fut = {
+            let context = context.clone();
+            async move {
+                let conn = context
+                    .endpoint
+                    .connect(node_id, &context.alpn)
+                    .await
+                    .map_err(PoolConnectError::from)?;
+                if let Some(on_connect) = &context.options.on_connected {
+                    on_connect(&context.endpoint, &conn)
+                        .await
+                        .map_err(PoolConnectError::from)?;
+                }
+                Result::<Connection, PoolConnectError>::Ok(conn)
             }
-            Result::<Connection, PoolConnectError>::Ok(conn)
         };
 
         // Connect to the node

From e0edf7779eaa3a168b2d9d60bde5c9a1bdd99720 Mon Sep 17 00:00:00 2001
From: Ruediger Klaehn
Date: Thu, 28 Aug 2025 09:53:14 +0200
Subject: [PATCH 16/16] Remove weird delay

---
 src/util/connection_pool.rs | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/util/connection_pool.rs b/src/util/connection_pool.rs
index f7187a5e9..aa9c15292 100644
--- a/src/util/connection_pool.rs
+++ b/src/util/connection_pool.rs
@@ -190,7 +190,6 @@ impl Context {
             }
             Err(e) => {
                 debug!(%node_id, "Failed to connect {e:?}, requesting shutdown");
-                tokio::time::sleep(Duration::from_secs(1)).await;
                 if context.owner.close(node_id).await.is_err() {
                     return;
                 }
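
After this series the pool surface used by the downloader is small:
construct it with an endpoint, an ALPN, and Options, then request
connections. A usage sketch distilled from the tests above (error
handling elided; assumes the Options fields shown in test_options(), with
ConnectionPool and Options imported from crate::util::connection_pool):

    use std::time::Duration;

    async fn demo(endpoint: iroh::Endpoint, node: iroh::NodeId) -> anyhow::Result<()> {
        const ALPN: &[u8] = b"echo";
        let pool = ConnectionPool::new(
            endpoint,
            ALPN,
            Options {
                idle_timeout: Duration::from_secs(5),
                connect_timeout: Duration::from_secs(5),
                max_connections: 32,
                on_connected: None, // optional async hook run after each dial
            },
        );
        // Repeated calls for the same node share one connection until it
        // idles out, fails, or is closed; then the pool dials again.
        let conn = pool.get_or_connect(node).await?;
        conn.close(0u32.into(), b"done");
        Ok(())
    }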