Skip to content

Commit a32e159

Browse files
committed
Implement on_connect cb.
1 parent 960bdfd commit a32e159

File tree

1 file changed

+44
-5
lines changed

1 file changed

+44
-5
lines changed

src/util/connection_pool.rs

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
//! the connection.
1111
use std::{
1212
collections::{HashMap, VecDeque},
13+
io,
1314
ops::Deref,
1415
sync::{
1516
atomic::{AtomicUsize, Ordering},
@@ -18,7 +19,10 @@ use std::{
1819
time::Duration,
1920
};
2021

21-
use iroh::{endpoint::ConnectError, Endpoint, NodeId};
22+
use iroh::{
23+
endpoint::{ConnectError, Connection},
24+
Endpoint, NodeId,
25+
};
2226
use n0_future::{
2327
future::{self},
2428
FuturesUnordered, MaybeFuture, Stream, StreamExt,
@@ -31,12 +35,23 @@ use tokio::sync::{
3135
use tokio_util::time::FutureExt as TimeFutureExt;
3236
use tracing::{debug, error, info, trace};
3337

38+
pub type OnConnected =
39+
Arc<dyn Fn(&Endpoint, &Connection) -> n0_future::future::Boxed<io::Result<()>> + Send + Sync>;
40+
3441
/// Configuration options for the connection pool
35-
#[derive(Debug, Clone, Copy)]
42+
#[derive(derive_more::Debug, Clone)]
3643
pub struct Options {
44+
/// How long to keep idle connections around.
3745
pub idle_timeout: Duration,
46+
/// Timeout for connect. This includes the time spent in on_connect, if set.
3847
pub connect_timeout: Duration,
48+
/// Maximum number of connections to hand out.
3949
pub max_connections: usize,
50+
/// An optional callback that can be used to wait for the connection to enter some state.
51+
/// An example usage could be to wait for the connection to become direct before handing
52+
/// it out to the user.
53+
#[debug(skip)]
54+
pub on_connect: Option<OnConnected>,
4055
}
4156

4257
impl Default for Options {
@@ -45,6 +60,7 @@ impl Default for Options {
4560
idle_timeout: Duration::from_secs(5),
4661
connect_timeout: Duration::from_secs(1),
4762
max_connections: 1024,
63+
on_connect: None,
4864
}
4965
}
5066
}
@@ -88,6 +104,8 @@ pub enum PoolConnectError {
88104
TooManyConnections,
89105
/// Error during connect
90106
ConnectError { source: Arc<ConnectError> },
107+
/// Error during on_connect callback
108+
OnConnectError { source: Arc<io::Error> },
91109
}
92110

93111
impl From<ConnectError> for PoolConnectError {
@@ -98,6 +116,14 @@ impl From<ConnectError> for PoolConnectError {
98116
}
99117
}
100118

119+
impl From<io::Error> for PoolConnectError {
120+
fn from(e: io::Error) -> Self {
121+
PoolConnectError::OnConnectError {
122+
source: Arc::new(e),
123+
}
124+
}
125+
}
126+
101127
/// Error when calling a fn on the [`ConnectionPool`].
102128
///
103129
/// The only thing that can go wrong is that the connection pool is shut down.
@@ -134,10 +160,23 @@ impl Context {
134160
) {
135161
let context = self;
136162

163+
let context2 = context.clone();
164+
let conn_fut = async move {
165+
let conn = context2
166+
.endpoint
167+
.connect(node_id, &context2.alpn)
168+
.await
169+
.map_err(PoolConnectError::from)?;
170+
if let Some(on_connect) = &context2.options.on_connect {
171+
on_connect(&context2.endpoint, &conn)
172+
.await
173+
.map_err(PoolConnectError::from)?;
174+
}
175+
Result::<Connection, PoolConnectError>::Ok(conn)
176+
};
177+
137178
// Connect to the node
138-
let state = context
139-
.endpoint
140-
.connect(node_id, &context.alpn)
179+
let state = conn_fut
141180
.timeout(context.options.connect_timeout)
142181
.await
143182
.map_err(|_| PoolConnectError::Timeout)

0 commit comments

Comments
 (0)