1010//! the connection.
1111use std:: {
1212 collections:: { HashMap , VecDeque } ,
13+ io,
1314 ops:: Deref ,
1415 sync:: {
1516 atomic:: { AtomicUsize , Ordering } ,
@@ -18,7 +19,10 @@ use std::{
1819 time:: Duration ,
1920} ;
2021
21- use iroh:: { endpoint:: ConnectError , Endpoint , NodeId } ;
22+ use iroh:: {
23+ endpoint:: { ConnectError , Connection } ,
24+ Endpoint , NodeId ,
25+ } ;
2226use n0_future:: {
2327 future:: { self } ,
2428 FuturesUnordered , MaybeFuture , Stream , StreamExt ,
@@ -31,12 +35,23 @@ use tokio::sync::{
3135use tokio_util:: time:: FutureExt as TimeFutureExt ;
3236use tracing:: { debug, error, info, trace} ;
3337
38+ pub type OnConnected =
39+ Arc < dyn Fn ( & Endpoint , & Connection ) -> n0_future:: future:: Boxed < io:: Result < ( ) > > + Send + Sync > ;
40+
3441/// Configuration options for the connection pool
35- #[ derive( Debug , Clone , Copy ) ]
42+ #[ derive( derive_more :: Debug , Clone ) ]
3643pub struct Options {
44+ /// How long to keep idle connections around.
3745 pub idle_timeout : Duration ,
46+ /// Timeout for connect. This includes the time spent in on_connect, if set.
3847 pub connect_timeout : Duration ,
48+ /// Maximum number of connections to hand out.
3949 pub max_connections : usize ,
50+ /// An optional callback that can be used to wait for the connection to enter some state.
51+ /// An example usage could be to wait for the connection to become direct before handing
52+ /// it out to the user.
53+ #[ debug( skip) ]
54+ pub on_connect : Option < OnConnected > ,
4055}
4156
4257impl Default for Options {
@@ -45,6 +60,7 @@ impl Default for Options {
4560 idle_timeout : Duration :: from_secs ( 5 ) ,
4661 connect_timeout : Duration :: from_secs ( 1 ) ,
4762 max_connections : 1024 ,
63+ on_connect : None ,
4864 }
4965 }
5066}
@@ -88,6 +104,8 @@ pub enum PoolConnectError {
88104 TooManyConnections ,
89105 /// Error during connect
90106 ConnectError { source : Arc < ConnectError > } ,
107+ /// Error during on_connect callback
108+ OnConnectError { source : Arc < io:: Error > } ,
91109}
92110
93111impl From < ConnectError > for PoolConnectError {
@@ -98,6 +116,14 @@ impl From<ConnectError> for PoolConnectError {
98116 }
99117}
100118
119+ impl From < io:: Error > for PoolConnectError {
120+ fn from ( e : io:: Error ) -> Self {
121+ PoolConnectError :: OnConnectError {
122+ source : Arc :: new ( e) ,
123+ }
124+ }
125+ }
126+
101127/// Error when calling a fn on the [`ConnectionPool`].
102128///
103129/// The only thing that can go wrong is that the connection pool is shut down.
@@ -134,10 +160,23 @@ impl Context {
134160 ) {
135161 let context = self ;
136162
163+ let context2 = context. clone ( ) ;
164+ let conn_fut = async move {
165+ let conn = context2
166+ . endpoint
167+ . connect ( node_id, & context2. alpn )
168+ . await
169+ . map_err ( PoolConnectError :: from) ?;
170+ if let Some ( on_connect) = & context2. options . on_connect {
171+ on_connect ( & context2. endpoint , & conn)
172+ . await
173+ . map_err ( PoolConnectError :: from) ?;
174+ }
175+ Result :: < Connection , PoolConnectError > :: Ok ( conn)
176+ } ;
177+
137178 // Connect to the node
138- let state = context
139- . endpoint
140- . connect ( node_id, & context. alpn )
179+ let state = conn_fut
141180 . timeout ( context. options . connect_timeout )
142181 . await
143182 . map_err ( |_| PoolConnectError :: Timeout )
0 commit comments