99//! It is important that you use the connection only in the future passed to
1010//! connect, and don't clone it out of the future.
1111use std:: {
12- collections:: { HashMap , VecDeque } ,
13- sync:: Arc ,
14- time:: Duration ,
12+ collections:: { HashMap , VecDeque } , ops:: Deref , sync:: Arc , time:: Duration
1513} ;
1614
1715use iroh:: {
1816 Endpoint , NodeId ,
1917 endpoint:: { ConnectError , Connection } ,
2018} ;
21- use n0_future:: { MaybeFuture , boxed :: BoxFuture } ;
19+ use n0_future:: MaybeFuture ;
2220use snafu:: Snafu ;
2321use tokio:: {
24- sync:: { mpsc, mpsc :: error:: SendError as TokioSendError , oneshot} ,
25- task:: { JoinError , JoinSet } ,
22+ sync:: { mpsc:: { self , error:: SendError as TokioSendError } , oneshot, OwnedSemaphorePermit } ,
23+ task:: JoinError ,
2624} ;
2725use tokio_util:: time:: FutureExt ;
2826use tracing:: { debug, error, trace} ;
@@ -45,15 +43,37 @@ impl Default for Options {
4543 }
4644}
4745
46+ /// A reference to a connection that is owned by a connection pool.
47+ #[ derive( Debug ) ]
48+ pub struct ConnectionRef {
49+ connection : iroh:: endpoint:: Connection ,
50+ _permit : OwnedSemaphorePermit ,
51+ }
52+
53+ impl Deref for ConnectionRef {
54+ type Target = iroh:: endpoint:: Connection ;
55+
56+ fn deref ( & self ) -> & Self :: Target {
57+ & self . connection
58+ }
59+ }
60+
61+ impl ConnectionRef {
62+ fn new ( connection : iroh:: endpoint:: Connection , permit : OwnedSemaphorePermit ) -> Self {
63+ Self {
64+ connection,
65+ _permit : permit,
66+ }
67+ }
68+ }
69+
4870struct Context {
4971 options : Options ,
5072 endpoint : Endpoint ,
5173 owner : ConnectionPool ,
5274 alpn : Vec < u8 > ,
5375}
5476
55- type BoxedHandler = Box < dyn FnOnce ( PoolConnectResult ) -> BoxFuture < ExecuteResult > + Send + ' static > ;
56-
5777/// Error when a connection can not be acquired
5878///
5979/// This includes the normal iroh connection errors as well as pool specific
@@ -87,19 +107,24 @@ impl std::fmt::Display for PoolConnectError {
87107pub type PoolConnectResult = std:: result:: Result < Connection , PoolConnectError > ;
88108
89109enum ActorMessage {
90- Handle { id : NodeId , handler : BoxedHandler } ,
110+ RequestRef ( RequestRef ) ,
91111 ConnectionIdle { id : NodeId } ,
92112 ConnectionShutdown { id : NodeId } ,
93113}
94114
115+ struct RequestRef {
116+ id : NodeId ,
117+ tx : oneshot:: Sender < Result < ConnectionRef , PoolConnectError > > ,
118+ }
119+
95120/// Run a connection actor for a single node
96121async fn run_connection_actor (
97122 node_id : NodeId ,
98- mut rx : mpsc:: Receiver < BoxedHandler > ,
123+ mut rx : mpsc:: Receiver < RequestRef > ,
99124 context : Arc < Context > ,
100125) {
101126 // Connect to the node
102- let mut state = match context
127+ let state = match context
103128 . endpoint
104129 . connect ( node_id, & context. alpn )
105130 . timeout ( context. options . connect_timeout )
@@ -115,10 +140,10 @@ async fn run_connection_actor(
115140 return ;
116141 }
117142 }
118-
119- let mut tasks = JoinSet :: new ( ) ;
143+ let semaphore = Arc :: new ( tokio:: sync:: Semaphore :: new ( u32:: MAX as usize ) ) ;
120144 let idle_timer = MaybeFuture :: default ( ) ;
121- tokio:: pin!( idle_timer) ;
145+ let idle_fut = MaybeFuture :: default ( ) ;
146+ tokio:: pin!( idle_timer, idle_fut) ;
122147
123148 loop {
124149 tokio:: select! {
@@ -127,11 +152,26 @@ async fn run_connection_actor(
127152 // Handle new work
128153 handler = rx. recv( ) => {
129154 match handler {
130- Some ( handler) => {
131- trace!( %node_id, "Received new task" ) ;
132- // clear the idle timer
133- idle_timer. as_mut( ) . set_none( ) ;
134- tasks. spawn( handler( state. clone( ) ) ) ;
155+ Some ( RequestRef { id, tx } ) => {
156+ assert!( id == node_id, "Not for me!" ) ;
157+ trace!( %node_id, "Received new request" ) ;
158+ match & state {
159+ Ok ( state) => {
160+ // first acquire a permit for the op, then acquire all permits for idle
161+ let permit = semaphore. clone( ) . acquire_owned( ) . await . expect( "semaphore closed" ) ;
162+ let res = ConnectionRef :: new( state. clone( ) , permit) ;
163+ if idle_fut. is_none( ) {
164+ idle_fut. as_mut( ) . set_future( semaphore. clone( ) . acquire_many_owned( u32 :: MAX ) ) ;
165+ }
166+
167+ // clear the idle timer
168+ idle_timer. as_mut( ) . set_none( ) ;
169+ tx. send( Ok ( res) ) . ok( ) ;
170+ }
171+ Err ( cause) => {
172+ tx. send( Err ( cause. clone( ) ) ) . ok( ) ;
173+ }
174+ }
135175 }
136176 None => {
137177 // Channel closed - finish remaining tasks and exit
@@ -140,43 +180,15 @@ async fn run_connection_actor(
140180 }
141181 }
142182
143- // Handle completed tasks
144- Some ( task_result) = tasks. join_next( ) , if !tasks. is_empty( ) => {
145- match task_result {
146- Ok ( Ok ( ( ) ) ) => {
147- debug!( %node_id, "Task completed" ) ;
148- }
149- Ok ( Err ( e) ) => {
150- error!( %node_id, "Task failed: {}" , e) ;
151- if let Ok ( conn) = state {
152- conn. close( 1u32 . into( ) , b"error" ) ;
153- }
154- state = Err ( PoolConnectError :: ExecuteError ( Arc :: new( e) ) ) ;
155- let _ = context. owner. close( node_id) . await ;
156- }
157- Err ( e) => {
158- error!( %node_id, "Task panicked: {}" , e) ;
159- if let Ok ( conn) = state {
160- conn. close( 1u32 . into( ) , b"panic" ) ;
161- }
162- state = Err ( PoolConnectError :: JoinError ( Arc :: new( e) ) ) ;
163- let _ = context. owner. close( node_id) . await ;
164- }
165- }
166-
167- // We are idle
168- if tasks. is_empty( ) {
169- // If the channel is closed, we can exit
170- if rx. is_closed( ) {
171- break ;
172- }
173- if context. owner. idle( node_id) . await . is_err( ) {
174- // If we can't notify the pool, we are shutting down
175- break ;
176- }
177- // set the idle timer
178- idle_timer. as_mut( ) . set_future( tokio:: time:: sleep( context. options. idle_timeout) ) ;
183+ _ = & mut idle_fut => {
184+ // notify the pool that we are idle.
185+ trace!( %node_id, "Idle" ) ;
186+ if context. owner. idle( node_id) . await . is_err( ) {
187+ // If we can't notify the pool, we are shutting down
188+ break ;
179189 }
190+ // set the idle timer
191+ idle_timer. as_mut( ) . set_future( tokio:: time:: sleep( context. options. idle_timeout) ) ;
180192 }
181193
182194 // Idle timeout - request shutdown
@@ -188,23 +200,21 @@ async fn run_connection_actor(
188200 }
189201 }
190202
191- // Wait for remaining tasks to complete
192- while let Some ( task_result) = tasks. join_next ( ) . await {
193- if let Err ( e) = task_result {
194- error ! ( %node_id, "Task failed during shutdown: {}" , e) ;
195- }
196- }
197-
198- if let Ok ( connection) = & state {
199- connection. close ( 0u32 . into ( ) , b"idle" ) ;
203+ if let Ok ( connection) = state {
204+ let reason = if semaphore. available_permits ( ) == u32:: MAX as usize {
205+ "idle"
206+ } else {
207+ "drop"
208+ } ;
209+ connection. close ( 0u32 . into ( ) , reason. as_bytes ( ) ) ;
200210 }
201211
202212 debug ! ( %node_id, "Connection actor shutting down" ) ;
203213}
204214
205215struct Actor {
206216 rx : mpsc:: Receiver < ActorMessage > ,
207- connections : HashMap < NodeId , mpsc:: Sender < BoxedHandler > > ,
217+ connections : HashMap < NodeId , mpsc:: Sender < RequestRef > > ,
208218 context : Arc < Context > ,
209219 // idle set (most recent last)
210220 // todo: use a better data structure if this becomes a performance issue
@@ -255,12 +265,13 @@ impl Actor {
255265 pub async fn run ( mut self ) {
256266 while let Some ( msg) = self . rx . recv ( ) . await {
257267 match msg {
258- ActorMessage :: Handle { id, mut handler } => {
268+ ActorMessage :: RequestRef ( mut msg) => {
269+ let id = msg. id ;
259270 self . remove_idle ( id) ;
260271 // Try to send to existing connection actor
261272 if let Some ( conn_tx) = self . connections . get ( & id) {
262- if let Err ( TokioSendError ( e) ) = conn_tx. send ( handler ) . await {
263- handler = e;
273+ if let Err ( TokioSendError ( e) ) = conn_tx. send ( msg ) . await {
274+ msg = e;
264275 } else {
265276 continue ;
266277 }
@@ -275,8 +286,7 @@ impl Actor {
275286 trace ! ( "removing oldest idle connection {}" , idle) ;
276287 self . connections . remove ( & idle) ;
277288 } else {
278- handler ( Err ( PoolConnectError :: TooManyConnections ) )
279- . await
289+ msg. tx . send ( Err ( PoolConnectError :: TooManyConnections ) )
280290 . ok ( ) ;
281291 continue ;
282292 }
@@ -289,7 +299,7 @@ impl Actor {
289299 tokio:: spawn ( run_connection_actor ( id, conn_rx, context) ) ;
290300
291301 // Send the handler to the new actor
292- if conn_tx. send ( handler ) . await . is_err ( ) {
302+ if conn_tx. send ( msg ) . await . is_err ( ) {
293303 error ! ( %id, "Failed to send handler to new connection actor" ) ;
294304 self . connections . remove ( & id) ;
295305 }
@@ -324,8 +334,6 @@ pub enum ConnectionPoolError {
324334#[ derive( Debug , Snafu ) ]
325335pub struct ExecuteError ;
326336
327- type ExecuteResult = std:: result:: Result < ( ) , ExecuteError > ;
328-
329337impl From < PoolConnectError > for ExecuteError {
330338 fn from ( _: PoolConnectError ) -> Self {
331339 ExecuteError
@@ -348,62 +356,17 @@ impl ConnectionPool {
348356 Self { tx }
349357 }
350358
351- /// Connect to a node and execute the given handler function
352- ///
353- /// The connection will either be a new connection or an existing one if it is already established.
354- /// If connection establishment succeeds, the handler will be called with a [`Ok`].
355- /// If connection establishment fails, the handler will get passed a [`Err`] containing the error.
356- ///
357- /// The fn f is guaranteed to be called exactly once, unless the tokio runtime is shutting down.
358- pub async fn connect < F , Fut > (
359+ pub async fn connect (
359360 & self ,
360361 id : NodeId ,
361- f : F ,
362- ) -> std:: result:: Result < ( ) , ConnectionPoolError >
363- where
364- F : FnOnce ( PoolConnectResult ) -> Fut + Send + ' static ,
365- Fut : Future < Output = ExecuteResult > + Send + ' static ,
362+ ) -> std:: result:: Result < std:: result:: Result < ConnectionRef , PoolConnectError > , ConnectionPoolError >
366363 {
367- let handler =
368- Box :: new ( move |conn : PoolConnectResult | Box :: pin ( f ( conn) ) as BoxFuture < ExecuteResult > ) ;
369-
364+ let ( tx, rx) = oneshot:: channel ( ) ;
370365 self . tx
371- . send ( ActorMessage :: Handle { id, handler } )
366+ . send ( ActorMessage :: RequestRef ( RequestRef { id, tx } ) )
372367 . await
373368 . map_err ( |_| ConnectionPoolError :: Shutdown ) ?;
374-
375- Ok ( ( ) )
376- }
377-
378- pub async fn with_connection < F , Fut , I , E > (
379- & self ,
380- id : NodeId ,
381- f : F ,
382- ) -> Result < Result < Result < I , E > , PoolConnectError > , ConnectionPoolError >
383- where
384- F : FnOnce ( Connection ) -> Fut + Send + ' static ,
385- Fut : Future < Output = Result < I , E > > + Send + ' static ,
386- I : Send + ' static ,
387- E : Send + ' static ,
388- {
389- let ( tx, rx) = oneshot:: channel ( ) ;
390- self . connect ( id, |conn| async move {
391- let ( res, ret) = match conn {
392- Ok ( connection) => {
393- let res = f ( connection) . await ;
394- let ret = match & res {
395- Ok ( _) => Ok ( ( ) ) ,
396- Err ( _) => Err ( ExecuteError ) ,
397- } ;
398- ( Ok ( res) , ret)
399- }
400- Err ( e) => ( Err ( e) , Err ( ExecuteError ) ) ,
401- } ;
402- tx. send ( res) . ok ( ) ;
403- ret
404- } )
405- . await ?;
406- rx. await . map_err ( |_| ConnectionPoolError :: Shutdown )
369+ Ok ( rx. await . map_err ( |_| ConnectionPoolError :: Shutdown ) ?)
407370 }
408371
409372 /// Close an existing connection, if it exists
0 commit comments