Skip to content

Commit 8ba7d91

Browse files
committed
Add test for connection pool on_connected timeout.
Also shut down routers.
1 parent a32e159 commit 8ba7d91

File tree

1 file changed

+101
-42
lines changed

1 file changed

+101
-42
lines changed

src/util/connection_pool.rs

Lines changed: 101 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ pub struct Options {
5151
/// An example usage could be to wait for the connection to become direct before handing
5252
/// it out to the user.
5353
#[debug(skip)]
54-
pub on_connect: Option<OnConnected>,
54+
pub on_connected: Option<OnConnected>,
5555
}
5656

5757
impl Default for Options {
@@ -60,7 +60,7 @@ impl Default for Options {
6060
idle_timeout: Duration::from_secs(5),
6161
connect_timeout: Duration::from_secs(1),
6262
max_connections: 1024,
63-
on_connect: None,
63+
on_connected: None,
6464
}
6565
}
6666
}
@@ -167,7 +167,7 @@ impl Context {
167167
.connect(node_id, &context2.alpn)
168168
.await
169169
.map_err(PoolConnectError::from)?;
170-
if let Some(on_connect) = &context2.options.on_connect {
170+
if let Some(on_connect) = &context2.options.on_connected {
171171
on_connect(&context2.endpoint, &conn)
172172
.await
173173
.map_err(PoolConnectError::from)?;
@@ -519,20 +519,22 @@ impl Drop for OneConnection {
519519

520520
#[cfg(test)]
521521
mod tests {
522-
use std::{collections::BTreeMap, time::Duration};
522+
use std::{collections::BTreeMap, sync::Arc, time::Duration};
523523

524524
use iroh::{
525525
discovery::static_provider::StaticProvider,
526526
endpoint::Connection,
527527
protocol::{AcceptError, ProtocolHandler, Router},
528528
NodeAddr, NodeId, SecretKey, Watcher,
529529
};
530-
use n0_future::{stream, BufferedStreamExt, StreamExt};
530+
use n0_future::{io, stream, BufferedStreamExt, StreamExt};
531531
use n0_snafu::ResultExt;
532532
use testresult::TestResult;
533533
use tracing::trace;
534+
use tracing_test::traced_test;
534535

535536
use super::{ConnectionPool, Options, PoolConnectError};
537+
use crate::util::connection_pool::OnConnected;
536538

537539
const ECHO_ALPN: &[u8] = b"echo";
538540

@@ -586,22 +588,33 @@ mod tests {
586588
Ok((addr, router))
587589
}
588590

589-
async fn echo_servers(n: usize) -> TestResult<Vec<(NodeAddr, Router)>> {
590-
stream::iter(0..n)
591+
async fn echo_servers(n: usize) -> TestResult<(Vec<NodeId>, Vec<Router>, StaticProvider)> {
592+
let res = stream::iter(0..n)
591593
.map(|_| echo_server())
592594
.buffered_unordered(16)
593595
.collect::<Vec<_>>()
594-
.await
595-
.into_iter()
596-
.collect()
596+
.await;
597+
let res: Vec<(NodeAddr, Router)> = res.into_iter().collect::<TestResult<Vec<_>>>()?;
598+
let (addrs, routers): (Vec<_>, Vec<_>) = res.into_iter().unzip();
599+
let ids = addrs.iter().map(|a| a.node_id).collect::<Vec<_>>();
600+
let discovery = StaticProvider::from_node_info(addrs);
601+
Ok((ids, routers, discovery))
602+
}
603+
604+
async fn shutdown_routers(routers: Vec<Router>) {
605+
stream::iter(routers)
606+
.for_each_concurrent(16, |router| async move {
607+
let _ = router.shutdown().await;
608+
})
609+
.await;
597610
}
598611

599612
fn test_options() -> Options {
600613
Options {
601614
idle_timeout: Duration::from_millis(100),
602615
connect_timeout: Duration::from_secs(2),
603616
max_connections: 32,
604-
on_connect: None,
617+
on_connected: None,
605618
}
606619
}
607620

@@ -625,12 +638,8 @@ mod tests {
625638
}
626639

627640
#[tokio::test]
641+
#[traced_test]
628642
async fn connection_pool_errors() -> TestResult<()> {
629-
let filter = tracing_subscriber::EnvFilter::from_default_env();
630-
tracing_subscriber::fmt()
631-
.with_env_filter(filter)
632-
.try_init()
633-
.ok();
634643
// set up static discovery for all addrs
635644
let discovery = StaticProvider::new();
636645
let endpoint = iroh::Endpoint::builder()
@@ -665,20 +674,10 @@ mod tests {
665674
}
666675

667676
#[tokio::test]
677+
#[traced_test]
668678
async fn connection_pool_smoke() -> TestResult<()> {
669-
let filter = tracing_subscriber::EnvFilter::from_default_env();
670-
tracing_subscriber::fmt()
671-
.with_env_filter(filter)
672-
.try_init()
673-
.ok();
674679
let n = 32;
675-
let nodes = echo_servers(n).await?;
676-
let ids = nodes
677-
.iter()
678-
.map(|(addr, _)| addr.node_id)
679-
.collect::<Vec<_>>();
680-
// set up static discovery for all addrs
681-
let discovery = StaticProvider::from_node_info(nodes.iter().map(|(addr, _)| addr.clone()));
680+
let (ids, routers, discovery) = echo_servers(n).await?;
682681
// build a client endpoint that can resolve all the node ids
683682
let endpoint = iroh::Endpoint::builder()
684683
.discovery(discovery.clone())
@@ -687,7 +686,7 @@ mod tests {
687686
let pool = ConnectionPool::new(endpoint.clone(), ECHO_ALPN, test_options());
688687
let client = EchoClient { pool };
689688
let mut connection_ids = BTreeMap::new();
690-
let msg = b"Hello, world!".to_vec();
689+
let msg = b"Hello, pool!".to_vec();
691690
for id in &ids {
692691
let (cid1, res) = client.echo(*id, msg.clone()).await??;
693692
assert_eq!(res, msg);
@@ -703,26 +702,17 @@ mod tests {
703702
assert_eq!(res, msg);
704703
assert_ne!(cid1, cid2);
705704
}
705+
shutdown_routers(routers).await;
706706
Ok(())
707707
}
708708

709709
/// Tests that idle connections are being reclaimed to make room if we hit the
710710
/// maximum connection limit.
711711
#[tokio::test]
712+
#[traced_test]
712713
async fn connection_pool_idle() -> TestResult<()> {
713-
let filter = tracing_subscriber::EnvFilter::from_default_env();
714-
tracing_subscriber::fmt()
715-
.with_env_filter(filter)
716-
.try_init()
717-
.ok();
718714
let n = 32;
719-
let nodes = echo_servers(n).await?;
720-
let ids = nodes
721-
.iter()
722-
.map(|(addr, _)| addr.node_id)
723-
.collect::<Vec<_>>();
724-
// set up static discovery for all addrs
725-
let discovery = StaticProvider::from_node_info(nodes.iter().map(|(addr, _)| addr.clone()));
715+
let (ids, routers, discovery) = echo_servers(n).await?;
726716
// build a client endpoint that can resolve all the node ids
727717
let endpoint = iroh::Endpoint::builder()
728718
.discovery(discovery.clone())
@@ -738,11 +728,80 @@ mod tests {
738728
},
739729
);
740730
let client = EchoClient { pool };
741-
let msg = b"Hello, world!".to_vec();
731+
let msg = b"Hello, pool!".to_vec();
742732
for id in &ids {
743733
let (_, res) = client.echo(*id, msg.clone()).await??;
744734
assert_eq!(res, msg);
745735
}
736+
shutdown_routers(routers).await;
737+
Ok(())
738+
}
739+
740+
/// Uses an on_connected callback that just errors out every time.
741+
///
742+
/// This is a basic smoke test that on_connected gets called at all.
743+
#[tokio::test]
744+
#[traced_test]
745+
async fn on_connected_error() -> TestResult<()> {
746+
let n = 1;
747+
let (ids, routers, discovery) = echo_servers(n).await?;
748+
let endpoint = iroh::Endpoint::builder()
749+
.discovery(discovery)
750+
.bind()
751+
.await?;
752+
let on_connected: OnConnected =
753+
Arc::new(|_, _| Box::pin(async { Err(io::Error::other("on_connect failed")) }));
754+
let pool = ConnectionPool::new(
755+
endpoint,
756+
ECHO_ALPN,
757+
Options {
758+
on_connected: Some(on_connected),
759+
..test_options()
760+
},
761+
);
762+
let client = EchoClient { pool };
763+
let msg = b"Hello, pool!".to_vec();
764+
for id in &ids {
765+
let res = client.echo(*id, msg.clone()).await;
766+
assert!(matches!(res, Err(PoolConnectError::OnConnectError { .. })));
767+
}
768+
shutdown_routers(routers).await;
769+
Ok(())
770+
}
771+
772+
/// Uses an on_connected callback that delays for a long time.
773+
///
774+
/// This checks that the pool timeout includes on_connected delay.
775+
#[tokio::test]
776+
#[traced_test]
777+
async fn on_connected_timeout() -> TestResult<()> {
778+
let n = 1;
779+
let (ids, routers, discovery) = echo_servers(n).await?;
780+
let endpoint = iroh::Endpoint::builder()
781+
.discovery(discovery)
782+
.bind()
783+
.await?;
784+
let on_connected: OnConnected = Arc::new(|_, _| {
785+
Box::pin(async {
786+
tokio::time::sleep(Duration::from_secs(2)).await;
787+
Ok(())
788+
})
789+
});
790+
let pool = ConnectionPool::new(
791+
endpoint,
792+
ECHO_ALPN,
793+
Options {
794+
on_connected: Some(on_connected),
795+
..test_options()
796+
},
797+
);
798+
let client = EchoClient { pool };
799+
let msg = b"Hello, pool!".to_vec();
800+
for id in &ids {
801+
let res = client.echo(*id, msg.clone()).await;
802+
assert!(matches!(res, Err(PoolConnectError::Timeout { .. })));
803+
}
804+
shutdown_routers(routers).await;
746805
Ok(())
747806
}
748807
}

0 commit comments

Comments
 (0)