@@ -127,8 +127,8 @@ pub use builder::NodeBuilder as Builder;
127127use chain:: ChainSource ;
128128use config:: {
129129 default_user_config, may_announce_channel, ChannelConfig , Config ,
130- LDK_EVENT_HANDLER_SHUTDOWN_TIMEOUT_SECS , NODE_ANN_BCAST_INTERVAL , PEER_RECONNECTION_INTERVAL ,
131- RGS_SYNC_INTERVAL ,
130+ BACKGROUND_TASK_SHUTDOWN_TIMEOUT_SECS , LDK_EVENT_HANDLER_SHUTDOWN_TIMEOUT_SECS ,
131+ NODE_ANN_BCAST_INTERVAL , PEER_RECONNECTION_INTERVAL , RGS_SYNC_INTERVAL ,
132132} ;
133133use connection:: ConnectionManager ;
134134use event:: { EventHandler , EventQueue } ;
@@ -179,6 +179,8 @@ pub struct Node {
179179 runtime : Arc < RwLock < Option < Arc < tokio:: runtime:: Runtime > > > > ,
180180 stop_sender : tokio:: sync:: watch:: Sender < ( ) > ,
181181 background_processor_task : Mutex < Option < tokio:: task:: JoinHandle < ( ) > > > ,
182+ background_tasks : Mutex < Option < tokio:: task:: JoinSet < ( ) > > > ,
183+ cancellable_background_tasks : Mutex < Option < tokio:: task:: JoinSet < ( ) > > > ,
182184 config : Arc < Config > ,
183185 wallet : Arc < Wallet > ,
184186 chain_source : Arc < ChainSource > ,
@@ -232,6 +234,10 @@ impl Node {
232234 return Err ( Error :: AlreadyRunning ) ;
233235 }
234236
237+ let mut background_tasks = tokio:: task:: JoinSet :: new ( ) ;
238+ let mut cancellable_background_tasks = tokio:: task:: JoinSet :: new ( ) ;
239+ let runtime_handle = runtime. handle ( ) ;
240+
235241 log_info ! (
236242 self . logger,
237243 "Starting up LDK Node with node ID {} on network: {}" ,
@@ -258,19 +264,27 @@ impl Node {
258264 let sync_cman = Arc :: clone ( & self . channel_manager ) ;
259265 let sync_cmon = Arc :: clone ( & self . chain_monitor ) ;
260266 let sync_sweeper = Arc :: clone ( & self . output_sweeper ) ;
261- runtime. spawn ( async move {
262- chain_source
263- . continuously_sync_wallets ( stop_sync_receiver, sync_cman, sync_cmon, sync_sweeper)
264- . await ;
265- } ) ;
267+ background_tasks. spawn_on (
268+ async move {
269+ chain_source
270+ . continuously_sync_wallets (
271+ stop_sync_receiver,
272+ sync_cman,
273+ sync_cmon,
274+ sync_sweeper,
275+ )
276+ . await ;
277+ } ,
278+ runtime_handle,
279+ ) ;
266280
267281 if self . gossip_source . is_rgs ( ) {
268282 let gossip_source = Arc :: clone ( & self . gossip_source ) ;
269283 let gossip_sync_store = Arc :: clone ( & self . kv_store ) ;
270284 let gossip_sync_logger = Arc :: clone ( & self . logger ) ;
271285 let gossip_node_metrics = Arc :: clone ( & self . node_metrics ) ;
272286 let mut stop_gossip_sync = self . stop_sender . subscribe ( ) ;
273- runtime . spawn ( async move {
287+ cancellable_background_tasks . spawn_on ( async move {
274288 let mut interval = tokio:: time:: interval ( RGS_SYNC_INTERVAL ) ;
275289 loop {
276290 tokio:: select! {
@@ -311,7 +325,7 @@ impl Node {
311325 }
312326 }
313327 }
314- } ) ;
328+ } , runtime_handle ) ;
315329 }
316330
317331 if let Some ( listening_addresses) = & self . config . listening_addresses {
@@ -337,7 +351,7 @@ impl Node {
337351 bind_addrs. extend ( resolved_address) ;
338352 }
339353
340- runtime . spawn ( async move {
354+ cancellable_background_tasks . spawn_on ( async move {
341355 {
342356 let listener =
343357 tokio:: net:: TcpListener :: bind ( & * bind_addrs) . await
@@ -356,7 +370,7 @@ impl Node {
356370 _ = stop_listen. changed( ) => {
357371 log_debug!(
358372 listening_logger,
359- "Stopping listening to inbound connections." ,
373+ "Stopping listening to inbound connections."
360374 ) ;
361375 break ;
362376 }
@@ -375,7 +389,7 @@ impl Node {
375389 }
376390
377391 listening_indicator. store ( false , Ordering :: Release ) ;
378- } ) ;
392+ } , runtime_handle ) ;
379393 }
380394
381395 // Regularly reconnect to persisted peers.
@@ -384,15 +398,15 @@ impl Node {
384398 let connect_logger = Arc :: clone ( & self . logger ) ;
385399 let connect_peer_store = Arc :: clone ( & self . peer_store ) ;
386400 let mut stop_connect = self . stop_sender . subscribe ( ) ;
387- runtime . spawn ( async move {
401+ cancellable_background_tasks . spawn_on ( async move {
388402 let mut interval = tokio:: time:: interval ( PEER_RECONNECTION_INTERVAL ) ;
389403 interval. set_missed_tick_behavior ( tokio:: time:: MissedTickBehavior :: Skip ) ;
390404 loop {
391405 tokio:: select! {
392406 _ = stop_connect. changed( ) => {
393407 log_debug!(
394408 connect_logger,
395- "Stopping reconnecting known peers." ,
409+ "Stopping reconnecting known peers."
396410 ) ;
397411 return ;
398412 }
@@ -412,7 +426,7 @@ impl Node {
412426 }
413427 }
414428 }
415- } ) ;
429+ } , runtime_handle ) ;
416430
417431 // Regularly broadcast node announcements.
418432 let bcast_cm = Arc :: clone ( & self . channel_manager ) ;
@@ -424,7 +438,7 @@ impl Node {
424438 let mut stop_bcast = self . stop_sender . subscribe ( ) ;
425439 let node_alias = self . config . node_alias . clone ( ) ;
426440 if may_announce_channel ( & self . config ) . is_ok ( ) {
427- runtime . spawn ( async move {
441+ cancellable_background_tasks . spawn_on ( async move {
428442 // We check every 30 secs whether our last broadcast is NODE_ANN_BCAST_INTERVAL away.
429443 #[ cfg( not( test) ) ]
430444 let mut interval = tokio:: time:: interval ( Duration :: from_secs ( 30 ) ) ;
@@ -495,14 +509,15 @@ impl Node {
495509 }
496510 }
497511 }
498- } ) ;
512+ } , runtime_handle ) ;
499513 }
500514
501515 let stop_tx_bcast = self . stop_sender . subscribe ( ) ;
502516 let chain_source = Arc :: clone ( & self . chain_source ) ;
503- runtime. spawn ( async move {
504- chain_source. continuously_process_broadcast_queue ( stop_tx_bcast) . await
505- } ) ;
517+ cancellable_background_tasks. spawn_on (
518+ async move { chain_source. continuously_process_broadcast_queue ( stop_tx_bcast) . await } ,
519+ runtime_handle,
520+ ) ;
506521
507522 let bump_tx_event_handler = Arc :: new ( BumpTransactionEventHandler :: new (
508523 Arc :: clone ( & self . tx_broadcaster ) ,
@@ -587,24 +602,33 @@ impl Node {
587602 let mut stop_liquidity_handler = self . stop_sender . subscribe ( ) ;
588603 let liquidity_handler = Arc :: clone ( & liquidity_source) ;
589604 let liquidity_logger = Arc :: clone ( & self . logger ) ;
590- runtime. spawn ( async move {
591- loop {
592- tokio:: select! {
593- _ = stop_liquidity_handler. changed( ) => {
594- log_debug!(
595- liquidity_logger,
596- "Stopping processing liquidity events." ,
597- ) ;
598- return ;
605+ background_tasks. spawn_on (
606+ async move {
607+ loop {
608+ tokio:: select! {
609+ _ = stop_liquidity_handler. changed( ) => {
610+ log_debug!(
611+ liquidity_logger,
612+ "Stopping processing liquidity events." ,
613+ ) ;
614+ return ;
615+ }
616+ _ = liquidity_handler. handle_next_event( ) => { }
599617 }
600- _ = liquidity_handler. handle_next_event( ) => { }
601618 }
602- }
603- } ) ;
619+ } ,
620+ runtime_handle,
621+ ) ;
604622 }
605623
606624 * runtime_lock = Some ( runtime) ;
607625
626+ debug_assert ! ( self . background_tasks. lock( ) . unwrap( ) . is_none( ) ) ;
627+ * self . background_tasks . lock ( ) . unwrap ( ) = Some ( background_tasks) ;
628+
629+ debug_assert ! ( self . cancellable_background_tasks. lock( ) . unwrap( ) . is_none( ) ) ;
630+ * self . cancellable_background_tasks . lock ( ) . unwrap ( ) = Some ( cancellable_background_tasks) ;
631+
608632 log_info ! ( self . logger, "Startup complete." ) ;
609633 Ok ( ( ) )
610634 }
@@ -635,6 +659,17 @@ impl Node {
635659 } ,
636660 }
637661
662+ // Cancel cancellable background tasks
663+ if let Some ( mut tasks) = self . cancellable_background_tasks . lock ( ) . unwrap ( ) . take ( ) {
664+ let runtime_2 = Arc :: clone ( & runtime) ;
665+ tasks. abort_all ( ) ;
666+ tokio:: task:: block_in_place ( move || {
667+ runtime_2. block_on ( async { while let Some ( _) = tasks. join_next ( ) . await { } } )
668+ } ) ;
669+ } else {
670+ debug_assert ! ( false , "Expected some cancellable background tasks" ) ;
671+ } ;
672+
638673 // Disconnect all peers.
639674 self . peer_manager . disconnect_all_peers ( ) ;
640675 log_debug ! ( self . logger, "Disconnected all network peers." ) ;
@@ -643,6 +678,46 @@ impl Node {
643678 self . chain_source . stop ( ) ;
644679 log_debug ! ( self . logger, "Stopped chain sources." ) ;
645680
681+ // Wait until non-cancellable background tasks (mod LDK's background processor) are done.
682+ let runtime_3 = Arc :: clone ( & runtime) ;
683+ if let Some ( mut tasks) = self . background_tasks . lock ( ) . unwrap ( ) . take ( ) {
684+ tokio:: task:: block_in_place ( move || {
685+ runtime_3. block_on ( async {
686+ loop {
687+ let timeout_fut = tokio:: time:: timeout (
688+ Duration :: from_secs ( BACKGROUND_TASK_SHUTDOWN_TIMEOUT_SECS ) ,
689+ tasks. join_next_with_id ( ) ,
690+ ) ;
691+ match timeout_fut. await {
692+ Ok ( Some ( Ok ( ( id, _) ) ) ) => {
693+ log_trace ! ( self . logger, "Stopped background task with id {}" , id) ;
694+ } ,
695+ Ok ( Some ( Err ( e) ) ) => {
696+ tasks. abort_all ( ) ;
697+ log_trace ! ( self . logger, "Stopping background task failed: {}" , e) ;
698+ break ;
699+ } ,
700+ Ok ( None ) => {
701+ log_debug ! ( self . logger, "Stopped all background tasks" ) ;
702+ break ;
703+ } ,
704+ Err ( e) => {
705+ tasks. abort_all ( ) ;
706+ log_error ! (
707+ self . logger,
708+ "Stopping background task timed out: {}" ,
709+ e
710+ ) ;
711+ break ;
712+ } ,
713+ }
714+ }
715+ } )
716+ } ) ;
717+ } else {
718+ debug_assert ! ( false , "Expected some background tasks" ) ;
719+ } ;
720+
646721 // Wait until background processing stopped, at least until a timeout is reached.
647722 if let Some ( background_processor_task) =
648723 self . background_processor_task . lock ( ) . unwrap ( ) . take ( )
@@ -676,7 +751,9 @@ impl Node {
676751 log_error ! ( self . logger, "Stopping event handling timed out: {}" , e) ;
677752 } ,
678753 }
679- }
754+ } else {
755+ debug_assert ! ( false , "Expected a background processing task" ) ;
756+ } ;
680757
681758 #[ cfg( tokio_unstable) ]
682759 {
0 commit comments