@@ -19,8 +19,8 @@ use std::{collections::HashMap, sync::Arc, time::SystemTime};
1919use thiserror:: Error ;
2020use tokio:: sync:: mpsc:: { channel, Receiver , Sender } ;
2121use tokio:: sync:: Mutex ;
22- use tokio:: task:: JoinSet ;
2322use tokio:: { select, time, time:: Duration } ;
23+ use tokio_util:: task:: TaskTracker ;
2424use triggered:: { Listener , Trigger } ;
2525
2626use self :: defined_activity:: DefinedPaymentActivity ;
@@ -518,6 +518,9 @@ pub struct Simulation {
518518 activity : Vec < ActivityDefinition > ,
519519 /// Results logger that holds the simulation statistics.
520520 results : Arc < Mutex < PaymentResultLogger > > ,
521+ /// Track all tasks spawned for use in the simulation. When used in the `run` method, it will wait for
522+ /// these tasks to complete before returning.
523+ tasks : TaskTracker ,
521524 /// High level triggers used to manage simulation tasks and shutdown.
522525 shutdown_trigger : Trigger ,
523526 shutdown_listener : Listener ,
@@ -546,13 +549,15 @@ impl Simulation {
546549 cfg : SimulationCfg ,
547550 nodes : HashMap < PublicKey , Arc < Mutex < dyn LightningNode > > > ,
548551 activity : Vec < ActivityDefinition > ,
552+ tasks : TaskTracker ,
549553 ) -> Self {
550554 let ( shutdown_trigger, shutdown_listener) = triggered:: trigger ( ) ;
551555 Self {
552556 cfg,
553557 nodes,
554558 activity,
555559 results : Arc :: new ( Mutex :: new ( PaymentResultLogger :: new ( ) ) ) ,
560+ tasks,
556561 shutdown_trigger,
557562 shutdown_listener,
558563 }
@@ -644,7 +649,19 @@ impl Simulation {
644649 Ok ( ( ) )
645650 }
646651
652+ /// run until the simulation completes or we hit an error.
653+ /// Note that it will wait for the tasks in self.tasks to complete
654+ /// before returning.
647655 pub async fn run ( & self ) -> Result < ( ) , SimulationError > {
656+ self . internal_run ( ) . await ?;
657+ // Close our TaskTracker and wait for any background tasks
658+ // spawned during internal_run to complete.
659+ self . tasks . close ( ) ;
660+ self . tasks . wait ( ) . await ;
661+ Ok ( ( ) )
662+ }
663+
664+ async fn internal_run ( & self ) -> Result < ( ) , SimulationError > {
648665 if let Some ( total_time) = self . cfg . total_time {
649666 log:: info!( "Running the simulation for {}s." , total_time. as_secs( ) ) ;
650667 } else {
@@ -659,7 +676,6 @@ impl Simulation {
659676 self . activity. len( ) ,
660677 self . nodes. len( )
661678 ) ;
662- let mut tasks = JoinSet :: new ( ) ;
663679
664680 // Before we start the simulation up, start tasks that will be responsible for gathering simulation data.
665681 // The event channels are shared across our functionality:
@@ -668,21 +684,15 @@ impl Simulation {
668684 // - Event Receiver: used by data reporting to receive events that have been simulated that need to be
669685 // tracked and recorded.
670686 let ( event_sender, event_receiver) = channel ( 1 ) ;
671- self . run_data_collection ( event_receiver, & mut tasks) ;
687+ self . run_data_collection ( event_receiver, & self . tasks ) ;
672688
673689 // Get an execution kit per activity that we need to generate and spin up consumers for each source node.
674690 let activities = match self . activity_executors ( ) . await {
675691 Ok ( a) => a,
676692 Err ( e) => {
677693 // If we encounter an error while setting up the activity_executors,
678- // we need to shutdown and wait for tasks to finish. We have started background tasks in the
679- // run_data_collection function, so we should shut those down before returning.
694+ // we need to shutdown and return.
680695 self . shutdown ( ) ;
681- while let Some ( res) = tasks. join_next ( ) . await {
682- if let Err ( e) = res {
683- log:: error!( "Task exited with error: {e}." ) ;
684- }
685- }
686696 return Err ( e) ;
687697 } ,
688698 } ;
@@ -692,40 +702,30 @@ impl Simulation {
692702 . map ( |generator| generator. source_info . pubkey )
693703 . collect ( ) ,
694704 event_sender. clone ( ) ,
695- & mut tasks,
705+ & self . tasks ,
696706 ) ;
697707
698708 // Next, we'll spin up our actual producers that will be responsible for triggering the configured activity.
699- // The producers will use their own JoinSet so that the simulation can be shutdown if they all finish.
700- let mut producer_tasks = JoinSet :: new ( ) ;
709+ // The producers will use their own TaskTracker so that the simulation can be shutdown if they all finish.
710+ let producer_tasks = TaskTracker :: new ( ) ;
701711 match self
702- . dispatch_producers ( activities, consumer_channels, & mut producer_tasks)
712+ . dispatch_producers ( activities, consumer_channels, & producer_tasks)
703713 . await
704714 {
705715 Ok ( _) => { } ,
706716 Err ( e) => {
707- // If we encounter an error in dispatch_producers, we need to shutdown and wait for tasks to finish.
708- // We have started background tasks in the run_data_collection function,
709- // so we should shut those down before returning.
717+ // If we encounter an error in dispatch_producers, we need to shutdown and return.
710718 self . shutdown ( ) ;
711- while let Some ( res) = tasks. join_next ( ) . await {
712- if let Err ( e) = res {
713- log:: error!( "Task exited with error: {e}." ) ;
714- }
715- }
716719 return Err ( e) ;
717720 } ,
718721 }
719722
720723 // Start a task that waits for the producers to finish.
721724 // If all producers finish, then there is nothing left to do and the simulation can be shutdown.
722725 let producer_trigger = self . shutdown_trigger . clone ( ) ;
723- tasks. spawn ( async move {
724- while let Some ( res) = producer_tasks. join_next ( ) . await {
725- if let Err ( e) = res {
726- log:: error!( "Producer exited with error: {e}." ) ;
727- }
728- }
726+ self . tasks . spawn ( async move {
727+ producer_tasks. close ( ) ;
728+ producer_tasks. wait ( ) . await ;
729729 log:: info!( "All producers finished. Shutting down." ) ;
730730 producer_trigger. trigger ( )
731731 } ) ;
@@ -735,7 +735,7 @@ impl Simulation {
735735 let t = self . shutdown_trigger . clone ( ) ;
736736 let l = self . shutdown_listener . clone ( ) ;
737737
738- tasks. spawn ( async move {
738+ self . tasks . spawn ( async move {
739739 if time:: timeout ( total_time, l) . await . is_err ( ) {
740740 log:: info!(
741741 "Simulation run for {}s. Shutting down." ,
@@ -746,18 +746,7 @@ impl Simulation {
746746 } ) ;
747747 }
748748
749- // We always want to wait for all threads to exit, so we wait for all of them to exit and track any errors
750- // that surface. It's okay if there are multiple and one is overwritten, we just want to know whether we
751- // exited with an error or not.
752- let mut success = true ;
753- while let Some ( res) = tasks. join_next ( ) . await {
754- if let Err ( e) = res {
755- log:: error!( "Task exited with error: {e}." ) ;
756- success = false ;
757- }
758- }
759-
760- success. then_some ( ( ) ) . ok_or ( SimulationError :: TaskError )
749+ Ok ( ( ) )
761750 }
762751
763752 pub fn shutdown ( & self ) {
@@ -777,7 +766,7 @@ impl Simulation {
777766 fn run_data_collection (
778767 & self ,
779768 output_receiver : Receiver < SimulationOutput > ,
780- tasks : & mut JoinSet < ( ) > ,
769+ tasks : & TaskTracker ,
781770 ) {
782771 let listener = self . shutdown_listener . clone ( ) ;
783772 let shutdown = self . shutdown_trigger . clone ( ) ;
@@ -790,11 +779,17 @@ impl Simulation {
790779 // psr: produce simulation results
791780 let psr_listener = listener. clone ( ) ;
792781 let psr_shutdown = shutdown. clone ( ) ;
782+ let psr_tasks = tasks. clone ( ) ;
793783 tasks. spawn ( async move {
794784 log:: debug!( "Starting simulation results producer." ) ;
795- if let Err ( e) =
796- produce_simulation_results ( nodes, output_receiver, results_sender, psr_listener)
797- . await
785+ if let Err ( e) = produce_simulation_results (
786+ nodes,
787+ output_receiver,
788+ results_sender,
789+ psr_listener,
790+ & psr_tasks,
791+ )
792+ . await
798793 {
799794 psr_shutdown. trigger ( ) ;
800795 log:: error!( "Produce simulation results exited with error: {e:?}." ) ;
@@ -939,7 +934,7 @@ impl Simulation {
939934 & self ,
940935 consuming_nodes : HashSet < PublicKey > ,
941936 output_sender : Sender < SimulationOutput > ,
942- tasks : & mut JoinSet < ( ) > ,
937+ tasks : & TaskTracker ,
943938 ) -> HashMap < PublicKey , Sender < SimulationEvent > > {
944939 let mut channels = HashMap :: new ( ) ;
945940
@@ -984,7 +979,7 @@ impl Simulation {
984979 & self ,
985980 executors : Vec < ExecutorKit > ,
986981 producer_channels : HashMap < PublicKey , Sender < SimulationEvent > > ,
987- tasks : & mut JoinSet < ( ) > ,
982+ tasks : & TaskTracker ,
988983 ) -> Result < ( ) , SimulationError > {
989984 for executor in executors {
990985 let sender = producer_channels. get ( & executor. source_info . pubkey ) . ok_or (
@@ -1350,9 +1345,8 @@ async fn produce_simulation_results(
13501345 mut output_receiver : Receiver < SimulationOutput > ,
13511346 results : Sender < ( Payment , PaymentResult ) > ,
13521347 listener : Listener ,
1348+ tasks : & TaskTracker ,
13531349) -> Result < ( ) , SimulationError > {
1354- let mut set = tokio:: task:: JoinSet :: new ( ) ;
1355-
13561350 let result = loop {
13571351 tokio:: select! {
13581352 biased;
@@ -1365,7 +1359,7 @@ async fn produce_simulation_results(
13651359 match simulation_output{
13661360 SimulationOutput :: SendPaymentSuccess ( payment) => {
13671361 if let Some ( source_node) = nodes. get( & payment. source) {
1368- set . spawn( track_payment_result(
1362+ tasks . spawn( track_payment_result(
13691363 source_node. clone( ) , results. clone( ) , payment, listener. clone( )
13701364 ) ) ;
13711365 } else {
@@ -1396,11 +1390,6 @@ async fn produce_simulation_results(
13961390 } ;
13971391
13981392 log:: debug!( "Simulation results producer exiting." ) ;
1399- while let Some ( res) = set. join_next ( ) . await {
1400- if let Err ( e) = res {
1401- log:: error!( "Simulation results producer task exited with error: {e}." ) ;
1402- }
1403- }
14041393
14051394 result
14061395}
@@ -1476,6 +1465,7 @@ mod tests {
14761465 use std:: sync:: Arc ;
14771466 use std:: time:: Duration ;
14781467 use tokio:: sync:: Mutex ;
1468+ use tokio_util:: task:: TaskTracker ;
14791469
14801470 #[ test]
14811471 fn create_seeded_mut_rng ( ) {
@@ -1619,6 +1609,7 @@ mod tests {
16191609 crate :: SimulationCfg :: new ( Some ( 0 ) , 0 , 0.0 , None , None ) ,
16201610 clients,
16211611 vec ! [ activity_definition] ,
1612+ TaskTracker :: new ( ) ,
16221613 ) ;
16231614 assert ! ( simulation. validate_activity( ) . await . is_err( ) ) ;
16241615 }
0 commit comments