@@ -51,7 +51,7 @@ pub struct Options {
5151 /// An example usage could be to wait for the connection to become direct before handing
5252 /// it out to the user.
5353 #[ debug( skip) ]
54- pub on_connect : Option < OnConnected > ,
54+ pub on_connected : Option < OnConnected > ,
5555}
5656
5757impl Default for Options {
@@ -60,7 +60,7 @@ impl Default for Options {
6060 idle_timeout : Duration :: from_secs ( 5 ) ,
6161 connect_timeout : Duration :: from_secs ( 1 ) ,
6262 max_connections : 1024 ,
63- on_connect : None ,
63+ on_connected : None ,
6464 }
6565 }
6666}
@@ -167,7 +167,7 @@ impl Context {
167167 . connect ( node_id, & context2. alpn )
168168 . await
169169 . map_err ( PoolConnectError :: from) ?;
170- if let Some ( on_connect) = & context2. options . on_connect {
170+ if let Some ( on_connect) = & context2. options . on_connected {
171171 on_connect ( & context2. endpoint , & conn)
172172 . await
173173 . map_err ( PoolConnectError :: from) ?;
@@ -519,20 +519,22 @@ impl Drop for OneConnection {
519519
520520#[ cfg( test) ]
521521mod tests {
522- use std:: { collections:: BTreeMap , time:: Duration } ;
522+ use std:: { collections:: BTreeMap , sync :: Arc , time:: Duration } ;
523523
524524 use iroh:: {
525525 discovery:: static_provider:: StaticProvider ,
526526 endpoint:: Connection ,
527527 protocol:: { AcceptError , ProtocolHandler , Router } ,
528528 NodeAddr , NodeId , SecretKey , Watcher ,
529529 } ;
530- use n0_future:: { stream, BufferedStreamExt , StreamExt } ;
530+ use n0_future:: { io , stream, BufferedStreamExt , StreamExt } ;
531531 use n0_snafu:: ResultExt ;
532532 use testresult:: TestResult ;
533533 use tracing:: trace;
534+ use tracing_test:: traced_test;
534535
535536 use super :: { ConnectionPool , Options , PoolConnectError } ;
537+ use crate :: util:: connection_pool:: OnConnected ;
536538
537539 const ECHO_ALPN : & [ u8 ] = b"echo" ;
538540
@@ -586,22 +588,33 @@ mod tests {
586588 Ok ( ( addr, router) )
587589 }
588590
589- async fn echo_servers ( n : usize ) -> TestResult < Vec < ( NodeAddr , Router ) > > {
590- stream:: iter ( 0 ..n)
591+ async fn echo_servers ( n : usize ) -> TestResult < ( Vec < NodeId > , Vec < Router > , StaticProvider ) > {
592+ let res = stream:: iter ( 0 ..n)
591593 . map ( |_| echo_server ( ) )
592594 . buffered_unordered ( 16 )
593595 . collect :: < Vec < _ > > ( )
594- . await
595- . into_iter ( )
596- . collect ( )
596+ . await ;
597+ let res: Vec < ( NodeAddr , Router ) > = res. into_iter ( ) . collect :: < TestResult < Vec < _ > > > ( ) ?;
598+ let ( addrs, routers) : ( Vec < _ > , Vec < _ > ) = res. into_iter ( ) . unzip ( ) ;
599+ let ids = addrs. iter ( ) . map ( |a| a. node_id ) . collect :: < Vec < _ > > ( ) ;
600+ let discovery = StaticProvider :: from_node_info ( addrs) ;
601+ Ok ( ( ids, routers, discovery) )
602+ }
603+
604+ async fn shutdown_routers ( routers : Vec < Router > ) {
605+ stream:: iter ( routers)
606+ . for_each_concurrent ( 16 , |router| async move {
607+ let _ = router. shutdown ( ) . await ;
608+ } )
609+ . await ;
597610 }
598611
599612 fn test_options ( ) -> Options {
600613 Options {
601614 idle_timeout : Duration :: from_millis ( 100 ) ,
602615 connect_timeout : Duration :: from_secs ( 2 ) ,
603616 max_connections : 32 ,
604- on_connect : None ,
617+ on_connected : None ,
605618 }
606619 }
607620
@@ -625,12 +638,8 @@ mod tests {
625638 }
626639
627640 #[ tokio:: test]
641+ #[ traced_test]
628642 async fn connection_pool_errors ( ) -> TestResult < ( ) > {
629- let filter = tracing_subscriber:: EnvFilter :: from_default_env ( ) ;
630- tracing_subscriber:: fmt ( )
631- . with_env_filter ( filter)
632- . try_init ( )
633- . ok ( ) ;
634643 // set up static discovery for all addrs
635644 let discovery = StaticProvider :: new ( ) ;
636645 let endpoint = iroh:: Endpoint :: builder ( )
@@ -665,20 +674,10 @@ mod tests {
665674 }
666675
667676 #[ tokio:: test]
677+ #[ traced_test]
668678 async fn connection_pool_smoke ( ) -> TestResult < ( ) > {
669- let filter = tracing_subscriber:: EnvFilter :: from_default_env ( ) ;
670- tracing_subscriber:: fmt ( )
671- . with_env_filter ( filter)
672- . try_init ( )
673- . ok ( ) ;
674679 let n = 32 ;
675- let nodes = echo_servers ( n) . await ?;
676- let ids = nodes
677- . iter ( )
678- . map ( |( addr, _) | addr. node_id )
679- . collect :: < Vec < _ > > ( ) ;
680- // set up static discovery for all addrs
681- let discovery = StaticProvider :: from_node_info ( nodes. iter ( ) . map ( |( addr, _) | addr. clone ( ) ) ) ;
680+ let ( ids, routers, discovery) = echo_servers ( n) . await ?;
682681 // build a client endpoint that can resolve all the node ids
683682 let endpoint = iroh:: Endpoint :: builder ( )
684683 . discovery ( discovery. clone ( ) )
@@ -687,7 +686,7 @@ mod tests {
687686 let pool = ConnectionPool :: new ( endpoint. clone ( ) , ECHO_ALPN , test_options ( ) ) ;
688687 let client = EchoClient { pool } ;
689688 let mut connection_ids = BTreeMap :: new ( ) ;
690- let msg = b"Hello, world !" . to_vec ( ) ;
689+ let msg = b"Hello, pool !" . to_vec ( ) ;
691690 for id in & ids {
692691 let ( cid1, res) = client. echo ( * id, msg. clone ( ) ) . await ??;
693692 assert_eq ! ( res, msg) ;
@@ -703,26 +702,17 @@ mod tests {
703702 assert_eq ! ( res, msg) ;
704703 assert_ne ! ( cid1, cid2) ;
705704 }
705+ shutdown_routers ( routers) . await ;
706706 Ok ( ( ) )
707707 }
708708
709709 /// Tests that idle connections are being reclaimed to make room if we hit the
710710 /// maximum connection limit.
711711 #[ tokio:: test]
712+ #[ traced_test]
712713 async fn connection_pool_idle ( ) -> TestResult < ( ) > {
713- let filter = tracing_subscriber:: EnvFilter :: from_default_env ( ) ;
714- tracing_subscriber:: fmt ( )
715- . with_env_filter ( filter)
716- . try_init ( )
717- . ok ( ) ;
718714 let n = 32 ;
719- let nodes = echo_servers ( n) . await ?;
720- let ids = nodes
721- . iter ( )
722- . map ( |( addr, _) | addr. node_id )
723- . collect :: < Vec < _ > > ( ) ;
724- // set up static discovery for all addrs
725- let discovery = StaticProvider :: from_node_info ( nodes. iter ( ) . map ( |( addr, _) | addr. clone ( ) ) ) ;
715+ let ( ids, routers, discovery) = echo_servers ( n) . await ?;
726716 // build a client endpoint that can resolve all the node ids
727717 let endpoint = iroh:: Endpoint :: builder ( )
728718 . discovery ( discovery. clone ( ) )
@@ -738,11 +728,80 @@ mod tests {
738728 } ,
739729 ) ;
740730 let client = EchoClient { pool } ;
741- let msg = b"Hello, world !" . to_vec ( ) ;
731+ let msg = b"Hello, pool !" . to_vec ( ) ;
742732 for id in & ids {
743733 let ( _, res) = client. echo ( * id, msg. clone ( ) ) . await ??;
744734 assert_eq ! ( res, msg) ;
745735 }
736+ shutdown_routers ( routers) . await ;
737+ Ok ( ( ) )
738+ }
739+
740+ /// Uses an on_connected callback that just errors out every time.
741+ ///
742+ /// This is a basic smoke test that on_connected gets called at all.
743+ #[ tokio:: test]
744+ #[ traced_test]
745+ async fn on_connected_error ( ) -> TestResult < ( ) > {
746+ let n = 1 ;
747+ let ( ids, routers, discovery) = echo_servers ( n) . await ?;
748+ let endpoint = iroh:: Endpoint :: builder ( )
749+ . discovery ( discovery)
750+ . bind ( )
751+ . await ?;
752+ let on_connected: OnConnected =
753+ Arc :: new ( |_, _| Box :: pin ( async { Err ( io:: Error :: other ( "on_connect failed" ) ) } ) ) ;
754+ let pool = ConnectionPool :: new (
755+ endpoint,
756+ ECHO_ALPN ,
757+ Options {
758+ on_connected : Some ( on_connected) ,
759+ ..test_options ( )
760+ } ,
761+ ) ;
762+ let client = EchoClient { pool } ;
763+ let msg = b"Hello, pool!" . to_vec ( ) ;
764+ for id in & ids {
765+ let res = client. echo ( * id, msg. clone ( ) ) . await ;
766+ assert ! ( matches!( res, Err ( PoolConnectError :: OnConnectError { .. } ) ) ) ;
767+ }
768+ shutdown_routers ( routers) . await ;
769+ Ok ( ( ) )
770+ }
771+
772+ /// Uses an on_connected callback that delays for a long time.
773+ ///
774+ /// This checks that the pool timeout includes on_connected delay.
775+ #[ tokio:: test]
776+ #[ traced_test]
777+ async fn on_connected_timeout ( ) -> TestResult < ( ) > {
778+ let n = 1 ;
779+ let ( ids, routers, discovery) = echo_servers ( n) . await ?;
780+ let endpoint = iroh:: Endpoint :: builder ( )
781+ . discovery ( discovery)
782+ . bind ( )
783+ . await ?;
784+ let on_connected: OnConnected = Arc :: new ( |_, _| {
785+ Box :: pin ( async {
786+ tokio:: time:: sleep ( Duration :: from_secs ( 2 ) ) . await ;
787+ Ok ( ( ) )
788+ } )
789+ } ) ;
790+ let pool = ConnectionPool :: new (
791+ endpoint,
792+ ECHO_ALPN ,
793+ Options {
794+ on_connected : Some ( on_connected) ,
795+ ..test_options ( )
796+ } ,
797+ ) ;
798+ let client = EchoClient { pool } ;
799+ let msg = b"Hello, pool!" . to_vec ( ) ;
800+ for id in & ids {
801+ let res = client. echo ( * id, msg. clone ( ) ) . await ;
802+ assert ! ( matches!( res, Err ( PoolConnectError :: Timeout { .. } ) ) ) ;
803+ }
804+ shutdown_routers ( routers) . await ;
746805 Ok ( ( ) )
747806 }
748807}
0 commit comments