@@ -477,3 +477,233 @@ impl Drop for OneConnection {
477477 }
478478 }
479479}
480+
481+ #[ cfg( test) ]
482+ mod tests {
483+ use std:: { collections:: BTreeMap , time:: Duration } ;
484+
485+ use iroh:: {
486+ discovery:: static_provider:: StaticProvider ,
487+ endpoint:: Connection ,
488+ protocol:: { AcceptError , ProtocolHandler , Router } ,
489+ NodeAddr , NodeId , SecretKey , Watcher ,
490+ } ;
491+ use n0_future:: { stream, BufferedStreamExt , StreamExt } ;
492+ use n0_snafu:: ResultExt ;
493+ use testresult:: TestResult ;
494+ use tracing:: trace;
495+
496+ use super :: { ConnectionPool , Options , PoolConnectError } ;
497+
498+ const ECHO_ALPN : & [ u8 ] = b"echo" ;
499+
500+ #[ derive( Debug , Clone ) ]
501+ struct Echo ;
502+
503+ impl ProtocolHandler for Echo {
504+ async fn accept ( & self , connection : Connection ) -> Result < ( ) , AcceptError > {
505+ let conn_id = connection. stable_id ( ) ;
506+ let id = connection. remote_node_id ( ) . map_err ( AcceptError :: from_err) ?;
507+ trace ! ( %id, %conn_id, "Accepting echo connection" ) ;
508+ loop {
509+ match connection. accept_bi ( ) . await {
510+ Ok ( ( mut send, mut recv) ) => {
511+ trace ! ( %id, %conn_id, "Accepted echo request" ) ;
512+ tokio:: io:: copy ( & mut recv, & mut send) . await ?;
513+ send. finish ( ) . map_err ( AcceptError :: from_err) ?;
514+ }
515+ Err ( e) => {
516+ trace ! ( %id, %conn_id, "Failed to accept echo request {e}" ) ;
517+ break ;
518+ }
519+ }
520+ }
521+ Ok ( ( ) )
522+ }
523+ }
524+
525+ async fn echo_client ( conn : & Connection , text : & [ u8 ] ) -> n0_snafu:: Result < Vec < u8 > > {
526+ let conn_id = conn. stable_id ( ) ;
527+ let id = conn. remote_node_id ( ) . e ( ) ?;
528+ trace ! ( %id, %conn_id, "Sending echo request" ) ;
529+ let ( mut send, mut recv) = conn. open_bi ( ) . await . e ( ) ?;
530+ send. write_all ( text) . await . e ( ) ?;
531+ send. finish ( ) . e ( ) ?;
532+ let response = recv. read_to_end ( 1000 ) . await . e ( ) ?;
533+ trace ! ( %id, %conn_id, "Received echo response" ) ;
534+ Ok ( response)
535+ }
536+
537+ async fn echo_server ( ) -> TestResult < ( NodeAddr , Router ) > {
538+ let endpoint = iroh:: Endpoint :: builder ( )
539+ . alpns ( vec ! [ ECHO_ALPN . to_vec( ) ] )
540+ . bind ( )
541+ . await ?;
542+ let addr = endpoint. node_addr ( ) . initialized ( ) . await ;
543+ let router = iroh:: protocol:: Router :: builder ( endpoint)
544+ . accept ( ECHO_ALPN , Echo )
545+ . spawn ( ) ;
546+
547+ Ok ( ( addr, router) )
548+ }
549+
550+ async fn echo_servers ( n : usize ) -> TestResult < Vec < ( NodeAddr , Router ) > > {
551+ stream:: iter ( 0 ..n)
552+ . map ( |_| echo_server ( ) )
553+ . buffered_unordered ( 16 )
554+ . collect :: < Vec < _ > > ( )
555+ . await
556+ . into_iter ( )
557+ . collect ( )
558+ }
559+
560+ fn test_options ( ) -> Options {
561+ Options {
562+ idle_timeout : Duration :: from_millis ( 100 ) ,
563+ connect_timeout : Duration :: from_secs ( 2 ) ,
564+ max_connections : 32 ,
565+ on_connect : None ,
566+ }
567+ }
568+
569+ struct EchoClient {
570+ pool : ConnectionPool ,
571+ }
572+
573+ impl EchoClient {
574+ async fn echo (
575+ & self ,
576+ id : NodeId ,
577+ text : Vec < u8 > ,
578+ ) -> Result < Result < ( usize , Vec < u8 > ) , n0_snafu:: Error > , PoolConnectError > {
579+ let conn = self . pool . get_or_connect ( id) . await ?;
580+ let id = conn. stable_id ( ) ;
581+ match echo_client ( & conn, & text) . await {
582+ Ok ( res) => Ok ( Ok ( ( id, res) ) ) ,
583+ Err ( e) => Ok ( Err ( e) ) ,
584+ }
585+ }
586+ }
587+
588+ #[ tokio:: test]
589+ async fn connection_pool_errors ( ) -> TestResult < ( ) > {
590+ let filter = tracing_subscriber:: EnvFilter :: from_default_env ( ) ;
591+ tracing_subscriber:: fmt ( )
592+ . with_env_filter ( filter)
593+ . try_init ( )
594+ . ok ( ) ;
595+ // set up static discovery for all addrs
596+ let discovery = StaticProvider :: new ( ) ;
597+ let endpoint = iroh:: Endpoint :: builder ( )
598+ . discovery ( discovery. clone ( ) )
599+ . bind ( )
600+ . await ?;
601+ let pool = ConnectionPool :: new ( endpoint, ECHO_ALPN , test_options ( ) ) ;
602+ let client = EchoClient { pool } ;
603+ {
604+ let non_existing = SecretKey :: from_bytes ( & [ 0 ; 32 ] ) . public ( ) ;
605+ let res = client. echo ( non_existing, b"Hello, world!" . to_vec ( ) ) . await ;
606+ // trying to connect to a non-existing id will fail with ConnectError
607+ // because we don't have any information about the node
608+ assert ! ( matches!( res, Err ( PoolConnectError :: ConnectError { .. } ) ) ) ;
609+ }
610+ {
611+ let non_listening = SecretKey :: from_bytes ( & [ 0 ; 32 ] ) . public ( ) ;
612+ // make up fake node info
613+ discovery. add_node_info ( NodeAddr {
614+ node_id : non_listening,
615+ relay_url : None ,
616+ direct_addresses : vec ! [ "127.0.0.1:12121" . parse( ) . unwrap( ) ]
617+ . into_iter ( )
618+ . collect ( ) ,
619+ } ) ;
620+ // trying to connect to an id for which we have info, but the other
621+ // end is not listening, will lead to a timeout.
622+ let res = client. echo ( non_listening, b"Hello, world!" . to_vec ( ) ) . await ;
623+ assert ! ( matches!( res, Err ( PoolConnectError :: Timeout ) ) ) ;
624+ }
625+ Ok ( ( ) )
626+ }
627+
628+ #[ tokio:: test]
629+ async fn connection_pool_smoke ( ) -> TestResult < ( ) > {
630+ let filter = tracing_subscriber:: EnvFilter :: from_default_env ( ) ;
631+ tracing_subscriber:: fmt ( )
632+ . with_env_filter ( filter)
633+ . try_init ( )
634+ . ok ( ) ;
635+ let n = 32 ;
636+ let nodes = echo_servers ( n) . await ?;
637+ let ids = nodes
638+ . iter ( )
639+ . map ( |( addr, _) | addr. node_id )
640+ . collect :: < Vec < _ > > ( ) ;
641+ // set up static discovery for all addrs
642+ let discovery = StaticProvider :: from_node_info ( nodes. iter ( ) . map ( |( addr, _) | addr. clone ( ) ) ) ;
643+ // build a client endpoint that can resolve all the node ids
644+ let endpoint = iroh:: Endpoint :: builder ( )
645+ . discovery ( discovery. clone ( ) )
646+ . bind ( )
647+ . await ?;
648+ let pool = ConnectionPool :: new ( endpoint. clone ( ) , ECHO_ALPN , test_options ( ) ) ;
649+ let client = EchoClient { pool } ;
650+ let mut connection_ids = BTreeMap :: new ( ) ;
651+ let msg = b"Hello, world!" . to_vec ( ) ;
652+ for id in & ids {
653+ let ( cid1, res) = client. echo ( * id, msg. clone ( ) ) . await ??;
654+ assert_eq ! ( res, msg) ;
655+ let ( cid2, res) = client. echo ( * id, msg. clone ( ) ) . await ??;
656+ assert_eq ! ( res, msg) ;
657+ assert_eq ! ( cid1, cid2) ;
658+ connection_ids. insert ( id, cid1) ;
659+ }
660+ tokio:: time:: sleep ( Duration :: from_millis ( 1000 ) ) . await ;
661+ for id in & ids {
662+ let cid1 = * connection_ids. get ( id) . expect ( "Connection ID not found" ) ;
663+ let ( cid2, res) = client. echo ( * id, msg. clone ( ) ) . await ??;
664+ assert_eq ! ( res, msg) ;
665+ assert_ne ! ( cid1, cid2) ;
666+ }
667+ Ok ( ( ) )
668+ }
669+
670+ /// Tests that idle connections are being reclaimed to make room if we hit the
671+ /// maximum connection limit.
672+ #[ tokio:: test]
673+ async fn connection_pool_idle ( ) -> TestResult < ( ) > {
674+ let filter = tracing_subscriber:: EnvFilter :: from_default_env ( ) ;
675+ tracing_subscriber:: fmt ( )
676+ . with_env_filter ( filter)
677+ . try_init ( )
678+ . ok ( ) ;
679+ let n = 32 ;
680+ let nodes = echo_servers ( n) . await ?;
681+ let ids = nodes
682+ . iter ( )
683+ . map ( |( addr, _) | addr. node_id )
684+ . collect :: < Vec < _ > > ( ) ;
685+ // set up static discovery for all addrs
686+ let discovery = StaticProvider :: from_node_info ( nodes. iter ( ) . map ( |( addr, _) | addr. clone ( ) ) ) ;
687+ // build a client endpoint that can resolve all the node ids
688+ let endpoint = iroh:: Endpoint :: builder ( )
689+ . discovery ( discovery. clone ( ) )
690+ . bind ( )
691+ . await ?;
692+ let pool = ConnectionPool :: new (
693+ endpoint. clone ( ) ,
694+ ECHO_ALPN ,
695+ Options {
696+ idle_timeout : Duration :: from_secs ( 100 ) ,
697+ max_connections : 8 ,
698+ ..test_options ( )
699+ } ,
700+ ) ;
701+ let client = EchoClient { pool } ;
702+ let msg = b"Hello, world!" . to_vec ( ) ;
703+ for id in & ids {
704+ let ( _, res) = client. echo ( * id, msg. clone ( ) ) . await ??;
705+ assert_eq ! ( res, msg) ;
706+ }
707+ Ok ( ( ) )
708+ }
709+ }
0 commit comments