@@ -4,10 +4,11 @@ use std::{
44} ;
55
66use anyhow:: { Context , Result } ;
7- use futures_util:: StreamExt ;
7+ use futures_util:: { StreamExt , stream :: FuturesUnordered } ;
88use opentelemetry:: trace:: TraceContextExt ;
9- use rivet_util:: { Id , signal:: TermSignal } ;
10- use tokio:: { signal:: ctrl_c, sync:: watch, task:: JoinHandle } ;
9+ use rivet_runtime:: TermSignal ;
10+ use rivet_util:: Id ;
11+ use tokio:: { sync:: watch, task:: JoinHandle } ;
1112use tracing:: Instrument ;
1213use tracing_opentelemetry:: OpenTelemetrySpanExt ;
1314
@@ -22,8 +23,6 @@ use crate::{
2223pub ( crate ) const PING_INTERVAL : Duration = Duration :: from_secs ( 10 ) ;
2324/// How often to publish metrics.
2425const METRICS_INTERVAL : Duration = Duration :: from_secs ( 20 ) ;
25- /// Time to allow running workflows to shutdown after receiving a SIGINT or SIGTERM.
26- const SHUTDOWN_DURATION : Duration = Duration :: from_secs ( 30 ) ;
2726// How long the pull workflows function can take before shutting down the runtime.
2827const PULL_WORKFLOWS_TIMEOUT : Duration = Duration :: from_secs ( 10 ) ;
2928
@@ -62,7 +61,8 @@ impl Worker {
6261 }
6362 }
6463
65- /// Polls the database periodically or wakes immediately when `Database::bump_sub` finishes
64+ /// Polls the database periodically or wakes immediately when `Database::bump_sub` finishes.
65+ /// Provide a shutdown_rx to allow shutting down without triggering SIGTERM.
6666 #[ tracing:: instrument( skip_all, fields( worker_id=%self . worker_id) ) ]
6767 pub async fn start ( mut self , mut shutdown_rx : Option < watch:: Receiver < ( ) > > ) -> Result < ( ) > {
6868 tracing:: debug!(
@@ -77,8 +77,7 @@ impl Worker {
7777 let mut tick_interval = tokio:: time:: interval ( self . db . worker_poll_interval ( ) ) ;
7878 tick_interval. set_missed_tick_behavior ( tokio:: time:: MissedTickBehavior :: Skip ) ;
7979
80- let mut term_signal =
81- TermSignal :: new ( ) . context ( "failed to setup termination signal handler" ) ?;
80+ let mut term_signal = TermSignal :: new ( ) . await ;
8281
8382 // Update ping at least once before doing anything else
8483 self . db
@@ -125,12 +124,11 @@ impl Worker {
125124 break Ok ( ( ) ) ;
126125 }
127126 }
128- _ = ctrl_c( ) => break Ok ( ( ) ) ,
129127 _ = term_signal. recv( ) => break Ok ( ( ) ) ,
130128 }
131129
132130 if let Err ( err) = self . tick ( & cache) . await {
133- // Cancel background tasks
131+ // Cancel background tasks. We abort because these are not critical tasks.
134132 gc_handle. abort ( ) ;
135133 metrics_handle. abort ( ) ;
136134
@@ -201,7 +199,7 @@ impl Worker {
201199 . span_context ( )
202200 . clone ( ) ;
203201
204- let handle = tokio:: task :: spawn (
202+ let handle = tokio:: spawn (
205203 // NOTE: No .in_current_span() because we want this to be a separate trace
206204 async move {
207205 if let Err ( err) = ctx. run ( current_span_ctx) . await {
@@ -226,7 +224,7 @@ impl Worker {
226224 let db = self . db . clone ( ) ;
227225 let worker_id = self . worker_id ;
228226
229- tokio:: task :: spawn (
227+ tokio:: spawn (
230228 async move {
231229 let mut ping_interval = tokio:: time:: interval ( PING_INTERVAL ) ;
232230 ping_interval. set_missed_tick_behavior ( tokio:: time:: MissedTickBehavior :: Skip ) ;
@@ -251,7 +249,7 @@ impl Worker {
251249 let db = self . db . clone ( ) ;
252250 let worker_id = self . worker_id ;
253251
254- tokio:: task :: spawn (
252+ tokio:: spawn (
255253 async move {
256254 let mut metrics_interval = tokio:: time:: interval ( METRICS_INTERVAL ) ;
257255 metrics_interval. set_missed_tick_behavior ( tokio:: time:: MissedTickBehavior :: Skip ) ;
@@ -270,79 +268,65 @@ impl Worker {
270268
271269 #[ tracing:: instrument( skip_all) ]
272270 async fn shutdown ( mut self , mut term_signal : TermSignal ) {
273- // Shutdown sequence
271+ let shutdown_duration = self . config . runtime . worker_shutdown_duration ( ) ;
272+
274273 tracing:: info!(
275- duration=?SHUTDOWN_DURATION ,
274+ duration=?shutdown_duration ,
276275 remaining_workflows=?self . running_workflows. len( ) ,
277276 "starting worker shutdown"
278277 ) ;
279278
280- let shutdown_start = Instant :: now ( ) ;
281-
282279 if let Err ( err) = self . db . mark_worker_inactive ( self . worker_id ) . await {
283280 tracing:: error!( ?err, worker_id=?self . worker_id, "failed to mark worker as inactive" ) ;
284281 }
285282
283+ // Send stop signal to all running workflows
286284 for ( workflow_id, wf) in & self . running_workflows {
287285 if wf. stop . send ( ( ) ) . is_err ( ) {
288- tracing:: warn !(
286+ tracing:: debug !(
289287 ?workflow_id,
290288 "stop channel closed, workflow likely already stopped"
291289 ) ;
292290 }
293291 }
294292
295- let mut second_sigterm = false ;
296- loop {
297- self . running_workflows
298- . retain ( |_, wf| !wf. handle . is_finished ( ) ) ;
293+ // Collect all workflow tasks
294+ let mut wf_futs = self
295+ . running_workflows
296+ . iter_mut ( )
297+ . map ( |( _, wf) | & mut wf. handle )
298+ . collect :: < FuturesUnordered < _ > > ( ) ;
299299
300- // Shutdown complete
301- if self . running_workflows . is_empty ( ) {
302- break ;
303- }
304-
305- if shutdown_start. elapsed ( ) > SHUTDOWN_DURATION {
306- tracing:: debug!( "shutdown timed out" ) ;
307- break ;
308- }
300+ let shutdown_start = Instant :: now ( ) ;
301+ loop {
302+ // Future will resolve once all workflow tasks complete
303+ let join_fut = async { while let Some ( _) = wf_futs. next ( ) . await { } } ;
309304
310305 tokio:: select! {
311- _ = ctrl_c( ) => {
312- if second_sigterm {
313- tracing:: warn!( "received third SIGTERM, aborting shutdown" ) ;
314- break ;
315- }
316-
317- tracing:: warn!( "received second SIGTERM" ) ;
318- second_sigterm = true ;
319-
320- continue ;
306+ _ = join_fut => {
307+ break ;
321308 }
322- _ = term_signal. recv( ) => {
323- if second_sigterm {
324- tracing:: warn!( "received third SIGTERM, aborting shutdown" ) ;
309+ abort = term_signal. recv( ) => {
310+ if abort {
311+ tracing:: warn!( "aborting worker shutdown" ) ;
325312 break ;
326313 }
327-
328- tracing:: warn!( "received second SIGTERM" ) ;
329- second_sigterm = true ;
330-
331- continue ;
332314 }
333- _ = tokio:: time:: sleep( Duration :: from_secs( 2 ) ) => { }
315+ _ = tokio:: time:: sleep( shutdown_duration. saturating_sub( shutdown_start. elapsed( ) ) ) => {
316+ tracing:: warn!( "worker shutdown timed out" ) ;
317+ break ;
318+ }
334319 }
335320 }
336321
337- if self . running_workflows . is_empty ( ) {
322+ let remaining_workflows = wf_futs. into_iter ( ) . count ( ) ;
323+ if remaining_workflows == 0 {
338324 tracing:: info!( "all workflows evicted" ) ;
339325 } else {
340326 tracing:: warn!( remaining_workflows=?self . running_workflows. len( ) , "not all workflows evicted" ) ;
341327 }
342328
343- tracing:: info!( "shutdown complete" ) ;
344-
345- rivet_runtime:: shutdown ( ) . await ;
329+ tracing:: info!( "worker shutdown complete" ) ;
346330 }
347331}
348332
0 commit comments