Skip to content

Commit 035284f

Browse files
committed
Introduce ConnectionRef instead of sending a fn.
1 parent 2ba5425 commit 035284f

File tree

2 files changed

+104
-136
lines changed

2 files changed

+104
-136
lines changed

iroh-connection-pool/src/connection_pool.rs

Lines changed: 86 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,18 @@
99
//! It is important that you use the connection only in the future passed to
1010
//! connect, and don't clone it out of the future.
1111
use std::{
12-
collections::{HashMap, VecDeque},
13-
sync::Arc,
14-
time::Duration,
12+
collections::{HashMap, VecDeque}, ops::Deref, sync::Arc, time::Duration
1513
};
1614

1715
use iroh::{
1816
Endpoint, NodeId,
1917
endpoint::{ConnectError, Connection},
2018
};
21-
use n0_future::{MaybeFuture, boxed::BoxFuture};
19+
use n0_future::MaybeFuture;
2220
use snafu::Snafu;
2321
use tokio::{
24-
sync::{mpsc, mpsc::error::SendError as TokioSendError, oneshot},
25-
task::{JoinError, JoinSet},
22+
sync::{mpsc::{self, error::SendError as TokioSendError}, oneshot, OwnedSemaphorePermit},
23+
task::JoinError,
2624
};
2725
use tokio_util::time::FutureExt;
2826
use tracing::{debug, error, trace};
@@ -45,15 +43,37 @@ impl Default for Options {
4543
}
4644
}
4745

46+
/// A reference to a connection that is owned by a connection pool.
47+
#[derive(Debug)]
48+
pub struct ConnectionRef {
49+
connection: iroh::endpoint::Connection,
50+
_permit: OwnedSemaphorePermit,
51+
}
52+
53+
impl Deref for ConnectionRef {
54+
type Target = iroh::endpoint::Connection;
55+
56+
fn deref(&self) -> &Self::Target {
57+
&self.connection
58+
}
59+
}
60+
61+
impl ConnectionRef {
62+
fn new(connection: iroh::endpoint::Connection, permit: OwnedSemaphorePermit) -> Self {
63+
Self {
64+
connection,
65+
_permit: permit,
66+
}
67+
}
68+
}
69+
4870
struct Context {
4971
options: Options,
5072
endpoint: Endpoint,
5173
owner: ConnectionPool,
5274
alpn: Vec<u8>,
5375
}
5476

55-
type BoxedHandler = Box<dyn FnOnce(PoolConnectResult) -> BoxFuture<ExecuteResult> + Send + 'static>;
56-
5777
/// Error when a connection can not be acquired
5878
///
5979
/// This includes the normal iroh connection errors as well as pool specific
@@ -87,19 +107,24 @@ impl std::fmt::Display for PoolConnectError {
87107
pub type PoolConnectResult = std::result::Result<Connection, PoolConnectError>;
88108

89109
enum ActorMessage {
90-
Handle { id: NodeId, handler: BoxedHandler },
110+
RequestRef(RequestRef),
91111
ConnectionIdle { id: NodeId },
92112
ConnectionShutdown { id: NodeId },
93113
}
94114

115+
struct RequestRef {
116+
id: NodeId,
117+
tx: oneshot::Sender<Result<ConnectionRef, PoolConnectError>>,
118+
}
119+
95120
/// Run a connection actor for a single node
96121
async fn run_connection_actor(
97122
node_id: NodeId,
98-
mut rx: mpsc::Receiver<BoxedHandler>,
123+
mut rx: mpsc::Receiver<RequestRef>,
99124
context: Arc<Context>,
100125
) {
101126
// Connect to the node
102-
let mut state = match context
127+
let state = match context
103128
.endpoint
104129
.connect(node_id, &context.alpn)
105130
.timeout(context.options.connect_timeout)
@@ -115,10 +140,10 @@ async fn run_connection_actor(
115140
return;
116141
}
117142
}
118-
119-
let mut tasks = JoinSet::new();
143+
let semaphore = Arc::new(tokio::sync::Semaphore::new(u32::MAX as usize));
120144
let idle_timer = MaybeFuture::default();
121-
tokio::pin!(idle_timer);
145+
let idle_fut = MaybeFuture::default();
146+
tokio::pin!(idle_timer, idle_fut);
122147

123148
loop {
124149
tokio::select! {
@@ -127,11 +152,26 @@ async fn run_connection_actor(
127152
// Handle new work
128153
handler = rx.recv() => {
129154
match handler {
130-
Some(handler) => {
131-
trace!(%node_id, "Received new task");
132-
// clear the idle timer
133-
idle_timer.as_mut().set_none();
134-
tasks.spawn(handler(state.clone()));
155+
Some(RequestRef { id, tx }) => {
156+
assert!(id == node_id, "Not for me!");
157+
trace!(%node_id, "Received new request");
158+
match &state {
159+
Ok(state) => {
160+
// first acquire a permit for the op, then acquire all permits for idle
161+
let permit = semaphore.clone().acquire_owned().await.expect("semaphore closed");
162+
let res = ConnectionRef::new(state.clone(), permit);
163+
if idle_fut.is_none() {
164+
idle_fut.as_mut().set_future(semaphore.clone().acquire_many_owned(u32::MAX));
165+
}
166+
167+
// clear the idle timer
168+
idle_timer.as_mut().set_none();
169+
tx.send(Ok(res)).ok();
170+
}
171+
Err(cause) => {
172+
tx.send(Err(cause.clone())).ok();
173+
}
174+
}
135175
}
136176
None => {
137177
// Channel closed - finish remaining tasks and exit
@@ -140,43 +180,15 @@ async fn run_connection_actor(
140180
}
141181
}
142182

143-
// Handle completed tasks
144-
Some(task_result) = tasks.join_next(), if !tasks.is_empty() => {
145-
match task_result {
146-
Ok(Ok(())) => {
147-
debug!(%node_id, "Task completed");
148-
}
149-
Ok(Err(e)) => {
150-
error!(%node_id, "Task failed: {}", e);
151-
if let Ok(conn) = state {
152-
conn.close(1u32.into(), b"error");
153-
}
154-
state = Err(PoolConnectError::ExecuteError(Arc::new(e)));
155-
let _ = context.owner.close(node_id).await;
156-
}
157-
Err(e) => {
158-
error!(%node_id, "Task panicked: {}", e);
159-
if let Ok(conn) = state {
160-
conn.close(1u32.into(), b"panic");
161-
}
162-
state = Err(PoolConnectError::JoinError(Arc::new(e)));
163-
let _ = context.owner.close(node_id).await;
164-
}
165-
}
166-
167-
// We are idle
168-
if tasks.is_empty() {
169-
// If the channel is closed, we can exit
170-
if rx.is_closed() {
171-
break;
172-
}
173-
if context.owner.idle(node_id).await.is_err() {
174-
// If we can't notify the pool, we are shutting down
175-
break;
176-
}
177-
// set the idle timer
178-
idle_timer.as_mut().set_future(tokio::time::sleep(context.options.idle_timeout));
183+
_ = &mut idle_fut => {
184+
// notify the pool that we are idle.
185+
trace!(%node_id, "Idle");
186+
if context.owner.idle(node_id).await.is_err() {
187+
// If we can't notify the pool, we are shutting down
188+
break;
179189
}
190+
// set the idle timer
191+
idle_timer.as_mut().set_future(tokio::time::sleep(context.options.idle_timeout));
180192
}
181193

182194
// Idle timeout - request shutdown
@@ -188,23 +200,21 @@ async fn run_connection_actor(
188200
}
189201
}
190202

191-
// Wait for remaining tasks to complete
192-
while let Some(task_result) = tasks.join_next().await {
193-
if let Err(e) = task_result {
194-
error!(%node_id, "Task failed during shutdown: {}", e);
195-
}
196-
}
197-
198-
if let Ok(connection) = &state {
199-
connection.close(0u32.into(), b"idle");
203+
if let Ok(connection) = state {
204+
let reason = if semaphore.available_permits() == u32::MAX as usize {
205+
"idle"
206+
} else {
207+
"drop"
208+
};
209+
connection.close(0u32.into(), reason.as_bytes());
200210
}
201211

202212
debug!(%node_id, "Connection actor shutting down");
203213
}
204214

205215
struct Actor {
206216
rx: mpsc::Receiver<ActorMessage>,
207-
connections: HashMap<NodeId, mpsc::Sender<BoxedHandler>>,
217+
connections: HashMap<NodeId, mpsc::Sender<RequestRef>>,
208218
context: Arc<Context>,
209219
// idle set (most recent last)
210220
// todo: use a better data structure if this becomes a performance issue
@@ -255,12 +265,13 @@ impl Actor {
255265
pub async fn run(mut self) {
256266
while let Some(msg) = self.rx.recv().await {
257267
match msg {
258-
ActorMessage::Handle { id, mut handler } => {
268+
ActorMessage::RequestRef(mut msg) => {
269+
let id = msg.id;
259270
self.remove_idle(id);
260271
// Try to send to existing connection actor
261272
if let Some(conn_tx) = self.connections.get(&id) {
262-
if let Err(TokioSendError(e)) = conn_tx.send(handler).await {
263-
handler = e;
273+
if let Err(TokioSendError(e)) = conn_tx.send(msg).await {
274+
msg = e;
264275
} else {
265276
continue;
266277
}
@@ -275,8 +286,7 @@ impl Actor {
275286
trace!("removing oldest idle connection {}", idle);
276287
self.connections.remove(&idle);
277288
} else {
278-
handler(Err(PoolConnectError::TooManyConnections))
279-
.await
289+
msg.tx.send(Err(PoolConnectError::TooManyConnections))
280290
.ok();
281291
continue;
282292
}
@@ -289,7 +299,7 @@ impl Actor {
289299
tokio::spawn(run_connection_actor(id, conn_rx, context));
290300

291301
// Send the handler to the new actor
292-
if conn_tx.send(handler).await.is_err() {
302+
if conn_tx.send(msg).await.is_err() {
293303
error!(%id, "Failed to send handler to new connection actor");
294304
self.connections.remove(&id);
295305
}
@@ -324,8 +334,6 @@ pub enum ConnectionPoolError {
324334
#[derive(Debug, Snafu)]
325335
pub struct ExecuteError;
326336

327-
type ExecuteResult = std::result::Result<(), ExecuteError>;
328-
329337
impl From<PoolConnectError> for ExecuteError {
330338
fn from(_: PoolConnectError) -> Self {
331339
ExecuteError
@@ -348,62 +356,17 @@ impl ConnectionPool {
348356
Self { tx }
349357
}
350358

351-
/// Connect to a node and execute the given handler function
352-
///
353-
/// The connection will either be a new connection or an existing one if it is already established.
354-
/// If connection establishment succeeds, the handler will be called with a [`Ok`].
355-
/// If connection establishment fails, the handler will get passed a [`Err`] containing the error.
356-
///
357-
/// The fn f is guaranteed to be called exactly once, unless the tokio runtime is shutting down.
358-
pub async fn connect<F, Fut>(
359+
pub async fn connect(
359360
&self,
360361
id: NodeId,
361-
f: F,
362-
) -> std::result::Result<(), ConnectionPoolError>
363-
where
364-
F: FnOnce(PoolConnectResult) -> Fut + Send + 'static,
365-
Fut: Future<Output = ExecuteResult> + Send + 'static,
362+
) -> std::result::Result<std::result::Result<ConnectionRef, PoolConnectError>, ConnectionPoolError>
366363
{
367-
let handler =
368-
Box::new(move |conn: PoolConnectResult| Box::pin(f(conn)) as BoxFuture<ExecuteResult>);
369-
364+
let (tx, rx) = oneshot::channel();
370365
self.tx
371-
.send(ActorMessage::Handle { id, handler })
366+
.send(ActorMessage::RequestRef(RequestRef { id, tx }))
372367
.await
373368
.map_err(|_| ConnectionPoolError::Shutdown)?;
374-
375-
Ok(())
376-
}
377-
378-
pub async fn with_connection<F, Fut, I, E>(
379-
&self,
380-
id: NodeId,
381-
f: F,
382-
) -> Result<Result<Result<I, E>, PoolConnectError>, ConnectionPoolError>
383-
where
384-
F: FnOnce(Connection) -> Fut + Send + 'static,
385-
Fut: Future<Output = Result<I, E>> + Send + 'static,
386-
I: Send + 'static,
387-
E: Send + 'static,
388-
{
389-
let (tx, rx) = oneshot::channel();
390-
self.connect(id, |conn| async move {
391-
let (res, ret) = match conn {
392-
Ok(connection) => {
393-
let res = f(connection).await;
394-
let ret = match &res {
395-
Ok(_) => Ok(()),
396-
Err(_) => Err(ExecuteError),
397-
};
398-
(Ok(res), ret)
399-
}
400-
Err(e) => (Err(e), Err(ExecuteError)),
401-
};
402-
tx.send(res).ok();
403-
ret
404-
})
405-
.await?;
406-
rx.await.map_err(|_| ConnectionPoolError::Shutdown)
369+
Ok(rx.await.map_err(|_| ConnectionPoolError::Shutdown)?)
407370
}
408371

409372
/// Close an existing connection, if it exists

0 commit comments

Comments
 (0)