diff --git a/src/util/connection_pool.rs b/src/util/connection_pool.rs index 7b283866d..aa9c15292 100644 --- a/src/util/connection_pool.rs +++ b/src/util/connection_pool.rs @@ -10,6 +10,7 @@ //! the connection. use std::{ collections::{HashMap, VecDeque}, + io, ops::Deref, sync::{ atomic::{AtomicUsize, Ordering}, @@ -18,7 +19,10 @@ use std::{ time::Duration, }; -use iroh::{endpoint::ConnectError, Endpoint, NodeId}; +use iroh::{ + endpoint::{ConnectError, Connection}, + Endpoint, NodeId, +}; use n0_future::{ future::{self}, FuturesUnordered, MaybeFuture, Stream, StreamExt, @@ -29,14 +33,25 @@ use tokio::sync::{ oneshot, Notify, }; use tokio_util::time::FutureExt as TimeFutureExt; -use tracing::{debug, error, trace}; +use tracing::{debug, error, info, trace}; + +pub type OnConnected = + Arc n0_future::future::Boxed> + Send + Sync>; /// Configuration options for the connection pool -#[derive(Debug, Clone, Copy)] +#[derive(derive_more::Debug, Clone)] pub struct Options { + /// How long to keep idle connections around. pub idle_timeout: Duration, + /// Timeout for connect. This includes the time spent in on_connect, if set. pub connect_timeout: Duration, + /// Maximum number of connections to hand out. pub max_connections: usize, + /// An optional callback that can be used to wait for the connection to enter some state. + /// An example usage could be to wait for the connection to become direct before handing + /// it out to the user. + #[debug(skip)] + pub on_connected: Option, } impl Default for Options { @@ -45,6 +60,7 @@ impl Default for Options { idle_timeout: Duration::from_secs(5), connect_timeout: Duration::from_secs(1), max_connections: 1024, + on_connected: None, } } } @@ -88,6 +104,8 @@ pub enum PoolConnectError { TooManyConnections, /// Error during connect ConnectError { source: Arc }, + /// Error during on_connect callback + OnConnectError { source: Arc }, } impl From for PoolConnectError { @@ -98,6 +116,14 @@ impl From for PoolConnectError { } } +impl From for PoolConnectError { + fn from(e: io::Error) -> Self { + PoolConnectError::OnConnectError { + source: Arc::new(e), + } + } +} + /// Error when calling a fn on the [`ConnectionPool`]. /// /// The only thing that can go wrong is that the connection pool is shut down. @@ -134,25 +160,48 @@ impl Context { ) { let context = self; + let conn_fut = { + let context = context.clone(); + async move { + let conn = context + .endpoint + .connect(node_id, &context.alpn) + .await + .map_err(PoolConnectError::from)?; + if let Some(on_connect) = &context.options.on_connected { + on_connect(&context.endpoint, &conn) + .await + .map_err(PoolConnectError::from)?; + } + Result::::Ok(conn) + } + }; + // Connect to the node - let state = context - .endpoint - .connect(node_id, &context.alpn) + let state = conn_fut .timeout(context.options.connect_timeout) .await .map_err(|_| PoolConnectError::Timeout) - .and_then(|r| r.map_err(PoolConnectError::from)); - if let Err(e) = &state { - debug!(%node_id, "Failed to connect {e:?}, requesting shutdown"); - if context.owner.close(node_id).await.is_err() { - return; + .and_then(|r| r); + let conn_close = match &state { + Ok(conn) => { + let conn = conn.clone(); + MaybeFuture::Some(async move { conn.closed().await }) } - } + Err(e) => { + debug!(%node_id, "Failed to connect {e:?}, requesting shutdown"); + if context.owner.close(node_id).await.is_err() { + return; + } + MaybeFuture::None + } + }; + let counter = ConnectionCounter::new(); let idle_timer = MaybeFuture::default(); let idle_stream = counter.clone().idle_stream(); - tokio::pin!(idle_timer, idle_stream); + tokio::pin!(idle_timer, idle_stream, conn_close); loop { tokio::select! { @@ -166,6 +215,7 @@ impl Context { match &state { Ok(state) => { let res = ConnectionRef::new(state.clone(), counter.get_one()); + info!(%node_id, "Handing out ConnectionRef {}", counter.current()); // clear the idle timer idle_timer.as_mut().set_none(); @@ -177,12 +227,17 @@ impl Context { } } None => { - // Channel closed - finish remaining tasks and exit + // Channel closed - exit break; } } } + _ = &mut conn_close => { + // connection was closed by somebody, notify owner that we should be removed + context.owner.close(node_id).await.ok(); + } + _ = idle_stream.next() => { if !counter.is_idle() { continue; @@ -417,6 +472,10 @@ impl ConnectionCounter { } } + fn current(&self) -> usize { + self.inner.count.load(Ordering::SeqCst) + } + /// Increase the connection count and return a guard for the new connection fn get_one(&self) -> OneConnection { self.inner.count.fetch_add(1, Ordering::SeqCst); @@ -458,3 +517,318 @@ impl Drop for OneConnection { } } } + +#[cfg(test)] +mod tests { + use std::{collections::BTreeMap, sync::Arc, time::Duration}; + + use iroh::{ + discovery::static_provider::StaticProvider, + endpoint::Connection, + protocol::{AcceptError, ProtocolHandler, Router}, + NodeAddr, NodeId, SecretKey, Watcher, + }; + use n0_future::{io, stream, BufferedStreamExt, StreamExt}; + use n0_snafu::ResultExt; + use testresult::TestResult; + use tracing::trace; + + use super::{ConnectionPool, Options, PoolConnectError}; + use crate::util::connection_pool::OnConnected; + + const ECHO_ALPN: &[u8] = b"echo"; + + #[derive(Debug, Clone)] + struct Echo; + + impl ProtocolHandler for Echo { + async fn accept(&self, connection: Connection) -> Result<(), AcceptError> { + let conn_id = connection.stable_id(); + let id = connection.remote_node_id().map_err(AcceptError::from_err)?; + trace!(%id, %conn_id, "Accepting echo connection"); + loop { + match connection.accept_bi().await { + Ok((mut send, mut recv)) => { + trace!(%id, %conn_id, "Accepted echo request"); + tokio::io::copy(&mut recv, &mut send).await?; + send.finish().map_err(AcceptError::from_err)?; + } + Err(e) => { + trace!(%id, %conn_id, "Failed to accept echo request {e}"); + break; + } + } + } + Ok(()) + } + } + + async fn echo_client(conn: &Connection, text: &[u8]) -> n0_snafu::Result> { + let conn_id = conn.stable_id(); + let id = conn.remote_node_id().e()?; + trace!(%id, %conn_id, "Sending echo request"); + let (mut send, mut recv) = conn.open_bi().await.e()?; + send.write_all(text).await.e()?; + send.finish().e()?; + let response = recv.read_to_end(1000).await.e()?; + trace!(%id, %conn_id, "Received echo response"); + Ok(response) + } + + async fn echo_server() -> TestResult<(NodeAddr, Router)> { + let endpoint = iroh::Endpoint::builder() + .alpns(vec![ECHO_ALPN.to_vec()]) + .bind() + .await?; + endpoint.home_relay().initialized().await; + let addr = endpoint.node_addr().initialized().await; + let router = iroh::protocol::Router::builder(endpoint) + .accept(ECHO_ALPN, Echo) + .spawn(); + + Ok((addr, router)) + } + + async fn echo_servers(n: usize) -> TestResult<(Vec, Vec, StaticProvider)> { + let res = stream::iter(0..n) + .map(|_| echo_server()) + .buffered_unordered(16) + .collect::>() + .await; + let res: Vec<(NodeAddr, Router)> = res.into_iter().collect::>>()?; + let (addrs, routers): (Vec<_>, Vec<_>) = res.into_iter().unzip(); + let ids = addrs.iter().map(|a| a.node_id).collect::>(); + let discovery = StaticProvider::from_node_info(addrs); + Ok((ids, routers, discovery)) + } + + async fn shutdown_routers(routers: Vec) { + stream::iter(routers) + .for_each_concurrent(16, |router| async move { + let _ = router.shutdown().await; + }) + .await; + } + + fn test_options() -> Options { + Options { + idle_timeout: Duration::from_millis(100), + connect_timeout: Duration::from_secs(5), + max_connections: 32, + on_connected: None, + } + } + + struct EchoClient { + pool: ConnectionPool, + } + + impl EchoClient { + async fn echo( + &self, + id: NodeId, + text: Vec, + ) -> Result), n0_snafu::Error>, PoolConnectError> { + let conn = self.pool.get_or_connect(id).await?; + let id = conn.stable_id(); + match echo_client(&conn, &text).await { + Ok(res) => Ok(Ok((id, res))), + Err(e) => Ok(Err(e)), + } + } + } + + #[tokio::test] + // #[traced_test] + async fn connection_pool_errors() -> TestResult<()> { + // set up static discovery for all addrs + let discovery = StaticProvider::new(); + let endpoint = iroh::Endpoint::builder() + .discovery(discovery.clone()) + .bind() + .await?; + let pool = ConnectionPool::new(endpoint, ECHO_ALPN, test_options()); + let client = EchoClient { pool }; + { + let non_existing = SecretKey::from_bytes(&[0; 32]).public(); + let res = client.echo(non_existing, b"Hello, world!".to_vec()).await; + // trying to connect to a non-existing id will fail with ConnectError + // because we don't have any information about the node + assert!(matches!(res, Err(PoolConnectError::ConnectError { .. }))); + } + { + let non_listening = SecretKey::from_bytes(&[0; 32]).public(); + // make up fake node info + discovery.add_node_info(NodeAddr { + node_id: non_listening, + relay_url: None, + direct_addresses: vec!["127.0.0.1:12121".parse().unwrap()] + .into_iter() + .collect(), + }); + // trying to connect to an id for which we have info, but the other + // end is not listening, will lead to a timeout. + let res = client.echo(non_listening, b"Hello, world!".to_vec()).await; + assert!(matches!(res, Err(PoolConnectError::Timeout))); + } + Ok(()) + } + + #[tokio::test] + // #[traced_test] + async fn connection_pool_smoke() -> TestResult<()> { + let n = 32; + let (ids, routers, discovery) = echo_servers(n).await?; + // build a client endpoint that can resolve all the node ids + let endpoint = iroh::Endpoint::builder() + .discovery(discovery.clone()) + .bind() + .await?; + let pool = ConnectionPool::new(endpoint.clone(), ECHO_ALPN, test_options()); + let client = EchoClient { pool }; + let mut connection_ids = BTreeMap::new(); + let msg = b"Hello, pool!".to_vec(); + for id in &ids { + let (cid1, res) = client.echo(*id, msg.clone()).await??; + assert_eq!(res, msg); + let (cid2, res) = client.echo(*id, msg.clone()).await??; + assert_eq!(res, msg); + assert_eq!(cid1, cid2); + connection_ids.insert(id, cid1); + } + tokio::time::sleep(Duration::from_millis(1000)).await; + for id in &ids { + let cid1 = *connection_ids.get(id).expect("Connection ID not found"); + let (cid2, res) = client.echo(*id, msg.clone()).await??; + assert_eq!(res, msg); + assert_ne!(cid1, cid2); + } + shutdown_routers(routers).await; + Ok(()) + } + + /// Tests that idle connections are being reclaimed to make room if we hit the + /// maximum connection limit. + #[tokio::test] + // #[traced_test] + async fn connection_pool_idle() -> TestResult<()> { + let n = 32; + let (ids, routers, discovery) = echo_servers(n).await?; + // build a client endpoint that can resolve all the node ids + let endpoint = iroh::Endpoint::builder() + .discovery(discovery.clone()) + .bind() + .await?; + let pool = ConnectionPool::new( + endpoint.clone(), + ECHO_ALPN, + Options { + idle_timeout: Duration::from_secs(100), + max_connections: 8, + ..test_options() + }, + ); + let client = EchoClient { pool }; + let msg = b"Hello, pool!".to_vec(); + for id in &ids { + let (_, res) = client.echo(*id, msg.clone()).await??; + assert_eq!(res, msg); + } + shutdown_routers(routers).await; + Ok(()) + } + + /// Uses an on_connected callback that just errors out every time. + /// + /// This is a basic smoke test that on_connected gets called at all. + #[tokio::test] + // #[traced_test] + async fn on_connected_error() -> TestResult<()> { + let n = 1; + let (ids, routers, discovery) = echo_servers(n).await?; + let endpoint = iroh::Endpoint::builder() + .discovery(discovery) + .bind() + .await?; + let on_connected: OnConnected = + Arc::new(|_, _| Box::pin(async { Err(io::Error::other("on_connect failed")) })); + let pool = ConnectionPool::new( + endpoint, + ECHO_ALPN, + Options { + on_connected: Some(on_connected), + ..test_options() + }, + ); + let client = EchoClient { pool }; + let msg = b"Hello, pool!".to_vec(); + for id in &ids { + let res = client.echo(*id, msg.clone()).await; + assert!(matches!(res, Err(PoolConnectError::OnConnectError { .. }))); + } + shutdown_routers(routers).await; + Ok(()) + } + + /// Uses an on_connected callback that delays for a long time. + /// + /// This checks that the pool timeout includes on_connected delay. + #[tokio::test] + // #[traced_test] + async fn on_connected_timeout() -> TestResult<()> { + let n = 1; + let (ids, routers, discovery) = echo_servers(n).await?; + let endpoint = iroh::Endpoint::builder() + .discovery(discovery) + .bind() + .await?; + let on_connected: OnConnected = Arc::new(|_, _| { + Box::pin(async { + tokio::time::sleep(Duration::from_secs(20)).await; + Ok(()) + }) + }); + let pool = ConnectionPool::new( + endpoint, + ECHO_ALPN, + Options { + on_connected: Some(on_connected), + ..test_options() + }, + ); + let client = EchoClient { pool }; + let msg = b"Hello, pool!".to_vec(); + for id in &ids { + let res = client.echo(*id, msg.clone()).await; + assert!(matches!(res, Err(PoolConnectError::Timeout))); + } + shutdown_routers(routers).await; + Ok(()) + } + + /// Check that when a connection is closed, the pool will give you a new + /// connection next time you want one. + /// + /// This test fails if the connection watch is disabled. + #[tokio::test] + // #[traced_test] + async fn watch_close() -> TestResult<()> { + let n = 1; + let (ids, routers, discovery) = echo_servers(n).await?; + let endpoint = iroh::Endpoint::builder() + .discovery(discovery) + .bind() + .await?; + + let pool = ConnectionPool::new(endpoint, ECHO_ALPN, test_options()); + let conn = pool.get_or_connect(ids[0]).await?; + let cid1 = conn.stable_id(); + conn.close(0u32.into(), b"test"); + tokio::time::sleep(Duration::from_millis(500)).await; + let conn = pool.get_or_connect(ids[0]).await?; + let cid2 = conn.stable_id(); + assert_ne!(cid1, cid2); + shutdown_routers(routers).await; + Ok(()) + } +}