From 09cec26fe270491b468779605f9945cdb0f4be91 Mon Sep 17 00:00:00 2001
From: MasterPtato
Date: Mon, 10 Nov 2025 14:49:26 -0800
Subject: [PATCH] fix: improve sigterm handling for the entire runtime

---
 Cargo.lock                                    |   3 +
 engine/packages/config/src/config/mod.rs      |   6 +-
 engine/packages/config/src/config/runtime.rs  |  31 +++
 engine/packages/engine/src/commands/start.rs  |   2 +-
 engine/packages/engine/src/run_config.rs      |  47 ++--
 engine/packages/engine/tests/common/ctx.rs    |  28 ++-
 engine/packages/gasoline/src/worker.rs        |  92 ++++----
 engine/packages/guard-core/Cargo.toml         |   1 +
 engine/packages/guard-core/src/server.rs      |  56 +++--
 .../pegboard/src/workflows/actor/mod.rs       |  26 +--
 engine/packages/runtime/src/lib.rs            |   3 +
 engine/packages/runtime/src/term_signal.rs    | 113 +++++++++
 engine/packages/service-manager/Cargo.toml    |   6 +-
 engine/packages/service-manager/src/lib.rs    | 215 ++++++++++++------
 engine/packages/util/src/lib.rs               |   1 -
 engine/packages/util/src/signal.rs            |  49 ----
 16 files changed, 440 insertions(+), 239 deletions(-)
 create mode 100644 engine/packages/config/src/config/runtime.rs
 create mode 100644 engine/packages/runtime/src/term_signal.rs
 delete mode 100644 engine/packages/util/src/signal.rs

diff --git a/Cargo.lock b/Cargo.lock
index da8c1dc40e..b4abc629a1 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4636,6 +4636,7 @@ dependencies = [
  "rivet-config",
  "rivet-error",
  "rivet-metrics",
+ "rivet-runtime",
  "rivet-util",
  "rustls 0.23.29",
  "rustls-pemfile 2.2.0",
@@ -4764,10 +4765,12 @@ version = "2.0.24-rc.1"
 dependencies = [
  "anyhow",
  "chrono",
+ "futures-util",
  "include_dir",
  "rivet-config",
  "rivet-metrics",
  "rivet-pools",
+ "rivet-runtime",
  "tokio",
  "tokio-cron-scheduler",
  "tracing",
diff --git a/engine/packages/config/src/config/mod.rs b/engine/packages/config/src/config/mod.rs
index 5ae4402115..439eb45b95 100644
--- a/engine/packages/config/src/config/mod.rs
+++ b/engine/packages/config/src/config/mod.rs
@@ -13,6 +13,7 @@ pub mod guard;
 pub mod logs;
 pub mod pegboard;
 pub mod pubsub;
+pub mod runtime;
 pub mod telemetry;
 pub mod topology;
 pub mod vector;
@@ -27,6 +28,7 @@ pub use guard::*;
 pub use logs::*;
 pub use pegboard::*;
 pub use pubsub::PubSub;
+pub use runtime::*;
 pub use telemetry::*;
 pub use topology::*;
 pub use vector::*;
@@ -102,7 +104,7 @@ pub struct Root {
 	pub telemetry: Telemetry,
 
 	#[serde(default)]
-	pub allow_version_rollback: bool,
+	pub runtime: Runtime,
 }
 
 impl Default for Root {
@@ -121,7 +123,7 @@ impl Default for Root {
 			clickhouse: None,
 			vector_http: None,
 			telemetry: Default::default(),
-			allow_version_rollback: false,
+			runtime: Default::default(),
 		}
 	}
 }
diff --git a/engine/packages/config/src/config/runtime.rs b/engine/packages/config/src/config/runtime.rs
new file mode 100644
index 0000000000..8b22affa86
--- /dev/null
+++ b/engine/packages/config/src/config/runtime.rs
@@ -0,0 +1,31 @@
+use std::time::Duration;
+
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+
+#[derive(Debug, Clone, Serialize, Deserialize, Default, JsonSchema)]
+pub struct Runtime {
+	/// Time (in seconds) to allow the gasoline worker engine to stop gracefully after receiving SIGTERM.
+	/// Defaults to 30 seconds.
+	worker_shutdown_duration: Option<u32>,
+	/// Time (in seconds) to allow guard to wait for pending requests after receiving SIGTERM. Defaults
+	/// to 1 hour.
+	guard_shutdown_duration: Option<u32>,
+	/// Whether or not to allow running the engine when the previously run version is higher than
+	/// the current version.
+	allow_version_rollback: Option<bool>,
+}
+
+impl Runtime {
+	pub fn worker_shutdown_duration(&self) -> Duration {
+		Duration::from_secs(self.worker_shutdown_duration.unwrap_or(30) as u64)
+	}
+
+	pub fn guard_shutdown_duration(&self) -> Duration {
+		Duration::from_secs(self.guard_shutdown_duration.unwrap_or(60 * 60) as u64)
+	}
+
+	pub fn allow_version_rollback(&self) -> bool {
+		self.allow_version_rollback.unwrap_or_default()
+	}
+}
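A quick sanity check of the defaults above, as a hedged sketch: deserializing an empty `runtime` table (which is what `#[serde(default)]` on `Root` produces when the key is absent) should yield the documented 30-second worker window, 1-hour guard window, and rollback protection enabled. The test placement and the `serde_json` dev-dependency are illustrative assumptions, not part of this patch:

    #[cfg(test)]
    mod tests {
        use super::Runtime;
        use std::time::Duration;

        #[test]
        fn runtime_defaults() {
            // An absent/empty table leaves every Option unset, so the
            // accessors fall back to the documented defaults.
            let runtime: Runtime = serde_json::from_str("{}").unwrap();
            assert_eq!(runtime.worker_shutdown_duration(), Duration::from_secs(30));
            assert_eq!(runtime.guard_shutdown_duration(), Duration::from_secs(60 * 60));
            assert!(!runtime.allow_version_rollback());
        }
    }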
diff --git a/engine/packages/engine/src/commands/start.rs b/engine/packages/engine/src/commands/start.rs
index dd6bda16f9..b6f9200c72 100644
--- a/engine/packages/engine/src/commands/start.rs
+++ b/engine/packages/engine/src/commands/start.rs
@@ -108,7 +108,7 @@ async fn verify_engine_version(
 	config: &rivet_config::Config,
 	pools: &rivet_pools::Pools,
 ) -> Result<()> {
-	if config.allow_version_rollback {
+	if config.runtime.allow_version_rollback() {
 		return Ok(());
 	}
diff --git a/engine/packages/engine/src/run_config.rs b/engine/packages/engine/src/run_config.rs
index 07d23d8c48..a268183537 100644
--- a/engine/packages/engine/src/run_config.rs
+++ b/engine/packages/engine/src/run_config.rs
@@ -3,33 +3,50 @@ use rivet_service_manager::{RunConfigData, Service, ServiceKind};
 
 pub fn config(_rivet_config: rivet_config::Config) -> Result<RunConfigData> {
 	let services = vec![
-		Service::new("api_peer", ServiceKind::ApiPeer, |config, pools| {
-			Box::pin(rivet_api_peer::start(config, pools))
-		}),
-		Service::new("guard", ServiceKind::ApiPublic, |config, pools| {
-			Box::pin(rivet_guard::start(config, pools))
-		}),
+		Service::new(
+			"api_peer",
+			ServiceKind::ApiPeer,
+			|config, pools| Box::pin(rivet_api_peer::start(config, pools)),
+			false,
+		),
+		Service::new(
+			"guard",
+			ServiceKind::ApiPublic,
+			|config, pools| Box::pin(rivet_guard::start(config, pools)),
+			true,
+		),
 		Service::new(
 			"workflow_worker",
 			ServiceKind::Standalone,
 			|config, pools| Box::pin(rivet_workflow_worker::start(config, pools)),
+			true,
+		),
+		Service::new(
+			"bootstrap",
+			ServiceKind::Oneshot,
+			|config, pools| Box::pin(rivet_bootstrap::start(config, pools)),
+			false,
 		),
-		Service::new("bootstrap", ServiceKind::Oneshot, |config, pools| {
-			Box::pin(rivet_bootstrap::start(config, pools))
-		}),
 		Service::new(
 			"pegboard_serverless",
 			// There should only be one of these, since it's auto-scaling requests
 			ServiceKind::Singleton,
 			|config, pools| Box::pin(pegboard_serverless::start(config, pools)),
+			false,
 		),
 		// Core services
-		Service::new("tracing_reconfigure", ServiceKind::Core, |config, pools| {
-			Box::pin(rivet_tracing_reconfigure::start(config, pools))
-		}),
-		Service::new("cache_purge", ServiceKind::Core, |config, pools| {
-			Box::pin(rivet_cache_purge::start(config, pools))
-		}),
+		Service::new(
+			"tracing_reconfigure",
+			ServiceKind::Core,
+			|config, pools| Box::pin(rivet_tracing_reconfigure::start(config, pools)),
+			false,
+		),
+		Service::new(
+			"cache_purge",
+			ServiceKind::Core,
+			|config, pools| Box::pin(rivet_cache_purge::start(config, pools)),
+			false,
+		),
 	];
 
 	Ok(RunConfigData { services })
diff --git a/engine/packages/engine/tests/common/ctx.rs b/engine/packages/engine/tests/common/ctx.rs
index a7a51f710d..344f52f588 100644
--- a/engine/packages/engine/tests/common/ctx.rs
+++ b/engine/packages/engine/tests/common/ctx.rs
@@ -78,20 +78,30 @@ impl TestCtx {
 			let pools = pools.clone();
 			async move {
 				let services = vec![
-					Service::new("api-peer", ServiceKind::ApiPeer, |config, pools| {
-						Box::pin(rivet_api_peer::start(config, pools))
-					}),
-					Service::new("guard", ServiceKind::Standalone, |config, pools| {
-						Box::pin(rivet_guard::start(config, pools))
-					}),
+					Service::new(
+						"api-peer",
+						ServiceKind::ApiPeer,
+						|config, pools| Box::pin(rivet_api_peer::start(config, pools)),
+						false,
+					),
+					Service::new(
+						"guard",
+						ServiceKind::Standalone,
+						|config, pools| Box::pin(rivet_guard::start(config, pools)),
+						true,
+					),
 					Service::new(
 						"workflow-worker",
 						ServiceKind::Standalone,
 						|config, pools| Box::pin(rivet_workflow_worker::start(config, pools)),
+						true,
+					),
+					Service::new(
+						"bootstrap",
+						ServiceKind::Oneshot,
+						|config, pools| Box::pin(rivet_bootstrap::start(config, pools)),
+						false,
 					),
-					Service::new("bootstrap", ServiceKind::Oneshot, |config, pools| {
-						Box::pin(rivet_bootstrap::start(config, pools))
-					}),
 				];
 
 				rivet_service_manager::start(config, pools, services).await
diff --git a/engine/packages/gasoline/src/worker.rs b/engine/packages/gasoline/src/worker.rs
index 1b3bde06fd..2e650cfb27 100644
--- a/engine/packages/gasoline/src/worker.rs
+++ b/engine/packages/gasoline/src/worker.rs
@@ -4,10 +4,11 @@ use std::{
 };
 
 use anyhow::{Context, Result};
-use futures_util::StreamExt;
+use futures_util::{StreamExt, stream::FuturesUnordered};
 use opentelemetry::trace::TraceContextExt;
-use rivet_util::{Id, signal::TermSignal};
-use tokio::{signal::ctrl_c, sync::watch, task::JoinHandle};
+use rivet_runtime::TermSignal;
+use rivet_util::Id;
+use tokio::{sync::watch, task::JoinHandle};
 use tracing::Instrument;
 use tracing_opentelemetry::OpenTelemetrySpanExt;
 
@@ -22,8 +23,6 @@ use crate::{
 pub(crate) const PING_INTERVAL: Duration = Duration::from_secs(10);
 /// How often to publish metrics.
 const METRICS_INTERVAL: Duration = Duration::from_secs(20);
-/// Time to allow running workflows to shutdown after receiving a SIGINT or SIGTERM.
-const SHUTDOWN_DURATION: Duration = Duration::from_secs(30);
 // How long the pull workflows function can take before shutting down the runtime.
 const PULL_WORKFLOWS_TIMEOUT: Duration = Duration::from_secs(10);
@@ -62,7 +61,8 @@ impl Worker {
 		}
 	}
 
-	/// Polls the database periodically or wakes immediately when `Database::bump_sub` finishes
+	/// Polls the database periodically or wakes immediately when `Database::bump_sub` finishes.
+	/// Provide a shutdown_rx to allow shutting down without triggering SIGTERM.
 	#[tracing::instrument(skip_all, fields(worker_id=%self.worker_id))]
 	pub async fn start(mut self, mut shutdown_rx: Option<watch::Receiver<()>>) -> Result<()> {
 		tracing::debug!(
@@ -77,8 +77,7 @@
 		let mut tick_interval = tokio::time::interval(self.db.worker_poll_interval());
 		tick_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
 
-		let mut term_signal =
-			TermSignal::new().context("failed to setup termination signal handler")?;
+		let mut term_signal = TermSignal::new().await;
 
 		// Update ping at least once before doing anything else
 		self.db
@@ -125,12 +124,11 @@
 					break Ok(());
 				}
 			}
-			_ = ctrl_c() => break Ok(()),
 			_ = term_signal.recv() => break Ok(()),
 		}
 
 		if let Err(err) = self.tick(&cache).await {
-			// Cancel background tasks
+			// Cancel background tasks. We abort because these are not critical tasks.
 			gc_handle.abort();
 			metrics_handle.abort();
 
@@ -201,7 +199,7 @@
 			.span_context()
 			.clone();
 
-		let handle = tokio::task::spawn(
+		let handle = tokio::spawn(
 			// NOTE: No .in_current_span() because we want this to be a separate trace
 			async move {
 				if let Err(err) = ctx.run(current_span_ctx).await {
@@ -226,7 +224,7 @@
 		let db = self.db.clone();
 		let worker_id = self.worker_id;
 
-		tokio::task::spawn(
+		tokio::spawn(
 			async move {
 				let mut ping_interval = tokio::time::interval(PING_INTERVAL);
 				ping_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
@@ -251,7 +249,7 @@
 		let db = self.db.clone();
 		let worker_id = self.worker_id;
 
-		tokio::task::spawn(
+		tokio::spawn(
 			async move {
 				let mut metrics_interval = tokio::time::interval(METRICS_INTERVAL);
 				metrics_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
@@ -270,79 +268,65 @@
 	#[tracing::instrument(skip_all)]
 	async fn shutdown(mut self, mut term_signal: TermSignal) {
-		// Shutdown sequence
+		let shutdown_duration = self.config.runtime.worker_shutdown_duration();
+
 		tracing::info!(
-			duration=?SHUTDOWN_DURATION,
+			duration=?shutdown_duration,
 			remaining_workflows=?self.running_workflows.len(),
 			"starting worker shutdown"
 		);
 
-		let shutdown_start = Instant::now();
-
 		if let Err(err) = self.db.mark_worker_inactive(self.worker_id).await {
 			tracing::error!(?err, worker_id=?self.worker_id, "failed to mark worker as inactive");
 		}
 
+		// Send stop signal to all running workflows
 		for (workflow_id, wf) in &self.running_workflows {
 			if wf.stop.send(()).is_err() {
-				tracing::warn!(
+				tracing::debug!(
 					?workflow_id,
 					"stop channel closed, workflow likely already stopped"
 				);
 			}
 		}
 
-		let mut second_sigterm = false;
-		loop {
-			self.running_workflows
-				.retain(|_, wf| !wf.handle.is_finished());
+		// Collect all workflow tasks
+		let mut wf_futs = self
+			.running_workflows
+			.iter_mut()
+			.map(|(_, wf)| &mut wf.handle)
+			.collect::<FuturesUnordered<_>>();
 
-			// Shutdown complete
-			if self.running_workflows.is_empty() {
-				break;
-			}
-
-			if shutdown_start.elapsed() > SHUTDOWN_DURATION {
-				tracing::debug!("shutdown timed out");
-				break;
-			}
+		let shutdown_start = Instant::now();
+		loop {
+			// Future will resolve once all workflow tasks complete
+			let join_fut = async { while let Some(_) = wf_futs.next().await {} };
 
 			tokio::select! {
-				_ = ctrl_c() => {
-					if second_sigterm {
-						tracing::warn!("received third SIGTERM, aborting shutdown");
-						break;
-					}
-
-					tracing::warn!("received second SIGTERM");
-					second_sigterm = true;
-
-					continue;
+				_ = join_fut => {
+					break;
 				}
-				_ = term_signal.recv() => {
-					if second_sigterm {
-						tracing::warn!("received third SIGTERM, aborting shutdown");
+				abort = term_signal.recv() => {
+					if abort {
+						tracing::warn!("aborting worker shutdown");
 						break;
 					}
-
-					tracing::warn!("received second SIGTERM");
-					second_sigterm = true;
-
-					continue;
 				}
-				_ = tokio::time::sleep(Duration::from_secs(2)) => {}
+				_ = tokio::time::sleep(shutdown_duration.saturating_sub(shutdown_start.elapsed())) => {
+					tracing::warn!("worker shutdown timed out");
+					break;
+				}
 			}
 		}
 
-		if self.running_workflows.is_empty() {
+		let remaining_workflows = wf_futs.into_iter().count();
+		if remaining_workflows == 0 {
 			tracing::info!("all workflows evicted");
 		} else {
-			tracing::warn!(remaining_workflows=?self.running_workflows.len(), "not all workflows evicted");
+			tracing::warn!(?remaining_workflows, "not all workflows evicted");
 		}
 
-		tracing::info!("shutdown complete");
-
-		rivet_runtime::shutdown().await;
+		tracing::info!("worker shutdown complete");
 	}
 }
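The shutdown loop above is the core pattern this patch introduces: drain a FuturesUnordered of task handles while racing against a deadline (and, in the patch, against further SIGTERMs). A minimal self-contained sketch of the same idea, with illustrative names that are not taken from this patch:

    use std::time::{Duration, Instant};
    use futures_util::{StreamExt, stream::FuturesUnordered};

    async fn drain_with_deadline(handles: Vec<tokio::task::JoinHandle<()>>, deadline: Duration) {
        // Poll all handles concurrently; completed tasks fall out of the set.
        let mut futs = handles.into_iter().collect::<FuturesUnordered<_>>();
        let start = Instant::now();
        loop {
            tokio::select! {
                res = futs.next() => {
                    // `None` means the set is empty: every task finished.
                    if res.is_none() { break; }
                }
                // Recompute the remaining budget each iteration, like the
                // patch's `saturating_sub(shutdown_start.elapsed())`.
                _ = tokio::time::sleep(deadline.saturating_sub(start.elapsed())) => {
                    eprintln!("shutdown timed out, {} tasks still running", futs.len());
                    break;
                }
            }
        }
    }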
diff --git a/engine/packages/guard-core/Cargo.toml b/engine/packages/guard-core/Cargo.toml
index e71adf9af1..c5185cf29d 100644
--- a/engine/packages/guard-core/Cargo.toml
+++ b/engine/packages/guard-core/Cargo.toml
@@ -32,6 +32,7 @@ rivet-api-builder.workspace = true
 rivet-config.workspace = true
 rivet-error.workspace = true
 rivet-metrics.workspace = true
+rivet-runtime.workspace = true
 rivet-util.workspace = true
 rustls.workspace = true
 rustls-pemfile.workspace = true
diff --git a/engine/packages/guard-core/src/server.rs b/engine/packages/guard-core/src/server.rs
index 78c74969f4..cb05688dcd 100644
--- a/engine/packages/guard-core/src/server.rs
+++ b/engine/packages/guard-core/src/server.rs
@@ -7,9 +7,10 @@ use std::{
 
 use crate::cert_resolver::{CertResolverFn, create_tls_config};
 use crate::metrics;
 use crate::proxy_service::{CacheKeyFn, MiddlewareFn, ProxyServiceFactory, RoutingFn};
-use anyhow::*;
+use anyhow::Result;
+use futures_util::FutureExt;
 use hyper::service::service_fn;
-use rivet_util::signal::TermSignal;
+use rivet_runtime::TermSignal;
 use tokio_rustls::TlsAcceptor;
 use tracing::Instrument;
@@ -76,7 +77,7 @@ pub async fn run_server(
 	let graceful = hyper_util::server::graceful::GracefulShutdown::new();
 
 	// Set up signal handling for graceful shutdown
-	let mut term_signal = TermSignal::new()?;
+	let mut term_signal = TermSignal::new().await;
 
 	tracing::info!("HTTP server listening on {}", http_addr);
 	if let Some(addr) = &https_addr {
@@ -129,7 +130,7 @@
 	// Accept connections until we receive a shutdown signal
 	loop {
-		let result: Result<()> = tokio::select! {
+		let res = tokio::select! {
 			conn = http_listener.accept() => {
 				match conn {
 					Result::Ok((tcp_stream, remote_addr)) => {
@@ -143,12 +144,12 @@
 						);
 					},
 					Err(err) => {
-						tracing::debug!(?err, "Accept error on HTTP port");
+						tracing::debug!(?err, "accept error on HTTP port");
 						tokio::time::sleep(Duration::from_secs(1)).await;
 					}
 				}
 				Ok(())
-			},
+			}
 			conn = async {
 				match &https_listener {
 					Some(listener) => Some(listener.accept().await),
@@ -160,7 +161,7 @@
 			} => {
 				if let Some(conn) = conn {
 					match conn {
-						Result::Ok((tcp_stream, remote_addr)) => {
+						Ok((tcp_stream, remote_addr)) => {
 							if let Some(factory) = &https_factory {
 								// Check if we have a TLS acceptor
 								if let Some(acceptor) = &https_acceptor {
@@ -230,34 +231,49 @@
 							}
 						},
 						Err(err) => {
-							tracing::debug!(?err, "Accept error on HTTPS port");
+							tracing::debug!(?err, "accept error on HTTPS port");
 							tokio::time::sleep(Duration::from_secs(1)).await;
 						}
 					}
 				}
-				Ok(())
-			},
+				anyhow::Ok(())
+			}
 			_ = term_signal.recv() => {
-				tracing::info!("Termination signal received, starting shutdown");
 				break;
 			}
 		};
 
-		if let Err(err) = result {
-			tracing::error!(?err, "Error in server loop");
+		if let Err(err) = res {
+			tracing::error!(?err, "error in guard server loop");
 		}
 	}
 
-	// Start graceful shutdown with timeout
-	tokio::select! {
-		_ = graceful.shutdown() => {
-			tracing::info!("Graceful shutdown completed");
-		},
-		_ = tokio::time::sleep(Duration::from_secs(30)) => {
-			tracing::error!("Waited 30 seconds for graceful shutdown, aborting...");
+	let shutdown_duration = config.runtime.guard_shutdown_duration();
+	tracing::info!(duration=?shutdown_duration, "starting guard shutdown");
+
+	let mut graceful_fut = async move { graceful.shutdown().await }.boxed();
+	let shutdown_start = Instant::now();
+	loop {
+		tokio::select! {
+			_ = &mut graceful_fut => {
+				tracing::info!("all guard requests completed");
+				break;
+			}
+			abort = term_signal.recv() => {
+				if abort {
+					tracing::warn!("aborting guard shutdown");
+					break;
+				}
+			}
+			_ = tokio::time::sleep(shutdown_duration.saturating_sub(shutdown_start.elapsed())) => {
+				tracing::warn!("guard shutdown timed out before all requests completed");
+				break;
+			}
 		}
 	}
 
+	tracing::info!("guard shutdown complete");
+
 	Ok(())
 }
diff --git a/engine/packages/pegboard/src/workflows/actor/mod.rs b/engine/packages/pegboard/src/workflows/actor/mod.rs
index a1f05e6a36..4f555e3a5d 100644
--- a/engine/packages/pegboard/src/workflows/actor/mod.rs
+++ b/engine/packages/pegboard/src/workflows/actor/mod.rs
@@ -533,6 +533,9 @@ pub async fn pegboard_actor(ctx: &mut WorkflowCtx, input: &Input) -> Result<()>
 		)
 		.await?;
 
+		// At this point, the actor is not allocated, so no cleanup related to
+		// alloc idx/desired slots needs to be done.
+
 		ctx.workflow(destroy::Input {
 			namespace_id: input.namespace_id,
 			actor_id: input.actor_id,
@@ -543,26 +546,6 @@
 		.output()
 		.await?;
 
-		// NOTE: The reason we allocate other actors from this actor workflow is because if we instead sent a
-		// signal to the runner wf here it would incur a heavy throughput hit and we need the runner wf to be as
-		// lightweight as possible; processing as few signals that aren't events/commands.
-		// Allocate other pending actors from queue
-		let res = ctx
-			.activity(AllocatePendingActorsInput {
-				namespace_id: input.namespace_id,
-				name: input.runner_name_selector.clone(),
-			})
-			.await?;
-
-		// Dispatch pending allocs
-		for alloc in res.allocations {
-			ctx.signal(alloc.signal)
-				.to_workflow::<Workflow>()
-				.tag("actor_id", alloc.actor_id)
-				.send()
-				.await?;
-		}
-
 		Ok(())
 	}
@@ -611,6 +594,9 @@ async fn handle_stopped(
 		})
 		.await?;
 
+	// NOTE: The reason we allocate other actors from this actor workflow is because if we instead sent a
+	// signal to the runner wf here, it would incur a heavy throughput hit, and we need the runner wf to be
+	// as lightweight as possible, processing as few signals that aren't events/commands as possible.
 	// Allocate other pending actors from queue since a slot has now cleared
 	let allocate_pending_res = ctx
 		.activity(AllocatePendingActorsInput {
diff --git a/engine/packages/runtime/src/lib.rs b/engine/packages/runtime/src/lib.rs
index 35f6090454..40113d50d8 100644
--- a/engine/packages/runtime/src/lib.rs
+++ b/engine/packages/runtime/src/lib.rs
@@ -5,6 +5,9 @@ use tokio::sync::{Notify, OnceCell};
 mod metrics;
 mod traces;
+mod term_signal;
+
+pub use term_signal::TermSignal;
 pub use traces::reload_log_filter;
 
 static SHUTDOWN: OnceCell<Arc<Notify>> = OnceCell::const_new();
diff --git a/engine/packages/runtime/src/term_signal.rs b/engine/packages/runtime/src/term_signal.rs
new file mode 100644
index 0000000000..b6ef0470bb
--- /dev/null
+++ b/engine/packages/runtime/src/term_signal.rs
@@ -0,0 +1,113 @@
+use anyhow::Result;
+
+use tokio::{
+	sync::{OnceCell, watch},
+	task::JoinHandle,
+};
+
+#[cfg(unix)]
+use tokio::signal::unix::{Signal, SignalKind, signal};
+
+#[cfg(windows)]
+use tokio::signal::windows::ctrl_c as windows_ctrl_c;
+
+const FORCE_CLOSE_THRESHOLD: usize = 3;
+
+static HANDLER_CELL: OnceCell<(watch::Receiver<bool>, JoinHandle<()>)> = OnceCell::const_new();
+
+/// Cross-platform termination signal wrapper that handles:
+/// - Unix: SIGTERM and SIGINT
+/// - Windows: Ctrl+C
+struct TermSignalHandler {
+	count: usize,
+	tx: watch::Sender<bool>,
+
+	#[cfg(unix)]
+	sigterm: Signal,
+	#[cfg(unix)]
+	sigint: Signal,
+	#[cfg(windows)]
+	ctrl_c: tokio::signal::windows::CtrlC,
+}
+
+impl TermSignalHandler {
+	/// Initializes the termination signal handler.
+	fn new() -> Result<Self> {
+		tracing::debug!("initializing termination signal handler");
+
+		Ok(Self {
+			count: 0,
+			tx: watch::channel(false).0,
+
+			#[cfg(unix)]
+			sigterm: signal(SignalKind::terminate())?,
+			#[cfg(unix)]
+			sigint: signal(SignalKind::interrupt())?,
+			#[cfg(windows)]
+			ctrl_c: windows_ctrl_c()?,
+		})
+	}
+
+	async fn run(mut self) {
+		loop {
+			#[cfg(unix)]
+			{
+				tokio::select! {
+					_ = self.sigterm.recv() => {}
+					_ = self.sigint.recv() => {}
+				}
+			}
+
+			#[cfg(windows)]
+			{
+				self.ctrl_c.recv().await;
+			}
+
+			self.count += 1;
+
+			if self.count == 1 {
+				tracing::info!("received SIGTERM");
+			} else {
+				tracing::warn!(count=%self.count, "received another SIGTERM");
+			}
+
+			if self.tx.send(self.count >= FORCE_CLOSE_THRESHOLD).is_err() {
+				tracing::debug!("no sigterm subscribers");
+			}
+		}
+	}
+}
+
+pub struct TermSignal(watch::Receiver<bool>);
+
+impl TermSignal {
+	/// Returns a handle to the shared termination signal handler, initializing it on first use.
+	pub async fn new() -> Self {
+		let rx = HANDLER_CELL
+			.get_or_init(|| {
+				let term_signal = TermSignalHandler::new()
+					.expect("failed initializing termination signal handler");
+				let rx = term_signal.tx.subscribe();
+
+				let join_handle = tokio::spawn(term_signal.run());
+
+				std::future::ready((rx, join_handle))
+			})
+			.await
+			.0
+			.clone();
+
+		TermSignal(rx)
+	}
+
+	/// Returns `true` if callers should abort any graceful shutdown attempt and shut down immediately.
+	pub async fn recv(&mut self) -> bool {
+		let _ = self.0.changed().await;
+		*self.0.borrow()
+	}
+
+	pub fn stop() {
+		if let Some((_, join_handle)) = HANDLER_CELL.get() {
+			join_handle.abort();
+		}
+	}
+}
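The consumer side, for reference: each service holds its own `TermSignal` (a clone of the shared watch receiver) and treats `recv() == true` as "stop being graceful". A minimal sketch of a graceful-drain helper mirroring the worker/guard shutdown loops above; `drain` is an illustrative stand-in for "wait for in-flight work", not a name from this patch:

    async fn shutdown_gracefully(
        mut term_signal: rivet_runtime::TermSignal,
        drain: impl std::future::Future<Output = ()>,
    ) {
        tokio::pin!(drain);
        loop {
            tokio::select! {
                // All in-flight work finished.
                _ = &mut drain => break,
                abort = term_signal.recv() => {
                    // `false`: another SIGTERM below FORCE_CLOSE_THRESHOLD;
                    // keep draining. `true`: force close was requested.
                    if abort {
                        tracing::warn!("force close requested, abandoning graceful drain");
                        break;
                    }
                }
            }
        }
    }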
diff --git a/engine/packages/service-manager/Cargo.toml b/engine/packages/service-manager/Cargo.toml
index 6373e43e75..d80027f74b 100644
--- a/engine/packages/service-manager/Cargo.toml
+++ b/engine/packages/service-manager/Cargo.toml
@@ -8,10 +8,12 @@ edition.workspace = true
 [dependencies]
 anyhow.workspace = true
 chrono.workspace = true
+futures-util.workspace = true
 include_dir.workspace = true
 rivet-config.workspace = true
 rivet-metrics.workspace = true
 rivet-pools.workspace = true
-tokio.workspace = true
+rivet-runtime.workspace = true
 tokio-cron-scheduler.workspace = true
-tracing.workspace = true
+tokio.workspace = true
+tracing.workspace = true
\ No newline at end of file
diff --git a/engine/packages/service-manager/src/lib.rs b/engine/packages/service-manager/src/lib.rs
index 0aa53e7e30..e3970f01ae 100644
--- a/engine/packages/service-manager/src/lib.rs
+++ b/engine/packages/service-manager/src/lib.rs
@@ -1,6 +1,15 @@
-use std::{future::Future, pin::Pin, sync::Arc, time::Duration};
-
-use anyhow::*;
+use std::{
+	future::Future,
+	pin::Pin,
+	sync::{
+		Arc,
+		atomic::{AtomicBool, Ordering},
+	},
+	time::Duration,
+};
+
+use anyhow::{Context, Result, ensure};
+use futures_util::{StreamExt, stream::FuturesUnordered};
 
 #[derive(Clone)]
 pub struct Service {
@@ -14,10 +23,16 @@ pub struct Service {
 			+ Send
 			+ Sync,
 	>,
+	pub requires_graceful_shutdown: bool,
 }
 
 impl Service {
-	pub fn new<F, Fut>(name: &'static str, kind: ServiceKind, run: F) -> Self
+	pub fn new<F, Fut>(
+		name: &'static str,
+		kind: ServiceKind,
+		run: F,
+		requires_graceful_shutdown: bool,
+	) -> Self
 	where
 		F: Fn(rivet_config::Config, rivet_pools::Pools) -> Fut + Send + Sync + 'static,
 		Fut: Future<Output = Result<()>> + Send + 'static,
 	{
 		Self {
 			name,
 			kind,
 			run: Arc::new(move |config, pools| Box::pin(run(config, pools))),
+			requires_graceful_shutdown,
 		}
 	}
 }
@@ -123,60 +139,74 @@ pub async fn start(
 ) -> Result<()> {
 	// Spawn services
 	tracing::info!(services=?services.len(), "starting services");
-	let mut join_set = tokio::task::JoinSet::new();
+	let mut running_services = Vec::new();
 	let cron_schedule = tokio_cron_scheduler::JobScheduler::new().await?;
-	let mut sleep_indefinitely = false;
+
+	let mut term_signal = rivet_runtime::TermSignal::new().await;
+	let shutting_down = Arc::new(AtomicBool::new(false));
+
 	for service in services {
 		tracing::debug!(name=%service.name, kind=?service.kind, "server starting service");
 
 		match service.kind.behavior() {
 			ServiceBehavior::Service => {
-				join_set
-					.build_task()
+				let config = config.clone();
+				let pools = pools.clone();
+				let shutting_down = shutting_down.clone();
+				let join_handle = tokio::task::Builder::new()
 					.name(&format!("rivet::service::{}", service.name))
-					.spawn({
-						let config = config.clone();
-						let pools = pools.clone();
-						async move {
-							tracing::debug!(service=%service.name, "starting service");
-
-							loop {
-								match (service.run)(config.clone(), pools.clone()).await {
-									Result::Ok(_) => {
+					.spawn(async move {
+						tracing::debug!(service=%service.name, "starting service");
+
+						loop {
+							match (service.run)(config.clone(), pools.clone()).await {
+								Result::Ok(res) => {
+									if shutting_down.load(Ordering::SeqCst) {
+										tracing::info!(service=%service.name, ?res, "service exited");
+										break;
+									} else {
 										tracing::error!(service=%service.name, "service exited unexpectedly");
 									}
-									Err(err) => {
-										tracing::error!(service=%service.name, ?err, "service crashed");
+								}
+								Err(err) => {
+									tracing::error!(service=%service.name, ?err, "service crashed");
+
+									if shutting_down.load(Ordering::SeqCst) {
+										break;
 									}
 								}
+							}
 
-								tokio::time::sleep(Duration::from_secs(1)).await;
+							tokio::time::sleep(Duration::from_secs(1)).await;
 
-								tracing::info!(service=%service.name, "restarting service");
-							}
+							tracing::info!(service=%service.name, "restarting service");
 						}
 					})
 					.context("failed to spawn service")?;
+
+				running_services.push((service.requires_graceful_shutdown, join_handle));
 			}
 			ServiceBehavior::Oneshot => {
-				join_set
-					.build_task()
+				let config = config.clone();
+				let pools = pools.clone();
+				let shutting_down = shutting_down.clone();
+				let join_handle = tokio::task::Builder::new()
 					.name(&format!("rivet::oneoff::{}", service.name))
-					.spawn({
-						let config = config.clone();
-						let pools = pools.clone();
-						async move {
-							tracing::debug!(oneoff=%service.name, "starting oneoff");
+					.spawn(async move {
+						tracing::debug!(oneoff=%service.name, "starting oneoff");
+
+						loop {
+							match (service.run)(config.clone(), pools.clone()).await {
+								Result::Ok(_) => {
+									tracing::debug!(oneoff=%service.name, "oneoff finished");
+									break;
+								}
+								Err(err) => {
+									tracing::error!(oneoff=%service.name, ?err, "oneoff crashed");
 
-							loop {
-								match (service.run)(config.clone(), pools.clone()).await {
-									Result::Ok(_) => {
-										tracing::debug!(oneoff=%service.name, "oneoff finished");
+									if shutting_down.load(Ordering::SeqCst) {
 										break;
-									}
-									Err(err) => {
-										tracing::error!(oneoff=%service.name, ?err, "oneoff crashed");
-
+									} else {
 										tokio::time::sleep(Duration::from_secs(1)).await;
 
 										tracing::info!(oneoff=%service.name, "restarting oneoff");
@@ -186,31 +216,33 @@ pub async fn start(
 								}
 							}
 						}
 					})
 					.context("failed to spawn oneoff")?;
+
+				running_services.push((service.requires_graceful_shutdown, join_handle));
 			}
 			ServiceBehavior::Cron(cron_config) => {
-				sleep_indefinitely = true;
-
 				// Spawn immediate task
 				if cron_config.run_immediately {
 					let service = service.clone();
-					join_set
-						.build_task()
+					let config = config.clone();
+					let pools = pools.clone();
+					let shutting_down = shutting_down.clone();
+					let join_handle = tokio::task::Builder::new()
 						.name(&format!("rivet::cron_immediate::{}", service.name))
-						.spawn({
-							let config = config.clone();
-							let pools = pools.clone();
-							async move {
-								tracing::debug!(cron=%service.name, "starting immediate cron");
+						.spawn(async move {
+							tracing::debug!(cron=%service.name, "starting immediate cron");
 
-								for attempt in 1..=8 {
-									match
-										(service.run)(config.clone(), pools.clone()).await {
-										Result::Ok(_) => {
-											tracing::debug!(cron=%service.name, ?attempt, "cron finished");
-											break;
-										}
-										Err(err) => {
-											tracing::error!(cron=%service.name, ?attempt, ?err, "cron crashed");
+							for attempt in 1..=8 {
+								match (service.run)(config.clone(), pools.clone()).await {
+									Result::Ok(_) => {
+										tracing::debug!(cron=%service.name, ?attempt, "cron finished");
+										// Return instead of break so the "failed all restart
+										// attempts" error below only fires on failure.
+										return;
+									}
+									Err(err) => {
+										tracing::error!(cron=%service.name, ?attempt, ?err, "cron crashed");
 
+										if shutting_down.load(Ordering::SeqCst) {
+											return;
+										} else {
 											tokio::time::sleep(Duration::from_secs(1)).await;
 
 											tracing::info!(cron=%service.name, ?attempt, "restarting cron");
+										}
 									}
 								}
 							}
+
+							tracing::error!(cron=%service.name, "cron failed all restart attempts");
 						})
 						.context("failed to spawn cron")?;
+
+					running_services.push((service.requires_graceful_shutdown, join_handle));
 				}
 
 				// Spawn cron
 				let config = config.clone();
 				let pools = pools.clone();
-				let service = service.clone();
+				let service2 = service.clone();
+				let shutting_down = shutting_down.clone();
 				cron_schedule
 					.add(tokio_cron_scheduler::Job::new_async_tz(
 						&cron_config.schedule,
@@ -233,7 +270,8 @@
 						move |notification, _| {
 							let config = config.clone();
 							let pools = pools.clone();
-							let service = service.clone();
+							let service = service2.clone();
+							let shutting_down = shutting_down.clone();
 
 							Box::pin(async move {
 								tracing::debug!(cron=%service.name, ?notification, "running cron");
@@ -246,31 +284,76 @@
 									Err(err) => {
 										tracing::error!(cron=%service.name, ?attempt, ?err, "cron crashed");
 
-										tokio::time::sleep(Duration::from_secs(1)).await;
+										if shutting_down.load(Ordering::SeqCst) {
+											return;
+										} else {
+											tokio::time::sleep(Duration::from_secs(1)).await;
 
-										tracing::info!(cron=%service.name, ?attempt, "restarting cron");
+											tracing::info!(cron=%service.name, ?attempt, "restarting cron");
+										}
 									}
 								}
 							}
+
+								tracing::error!(cron=%service.name, "cron failed all restart attempts");
 							})
 						},
 					)?)
 					.await?;
+
+				// Add dummy task to prevent the start command from stopping if there's a cron
+				let join_handle = tokio::task::Builder::new()
+					.name(&format!("rivet::cron_dummy::{}", service.name))
+					.spawn(std::future::pending())
+					.context("failed creating dummy cron task")?;
+				running_services.push((false, join_handle));
 			}
 		}
 	}
 
 	cron_schedule.start().await?;
 
-	if sleep_indefinitely {
-		std::future::pending().await
-	} else {
-		// Wait for services
-		join_set.join_all().await;
+	loop {
+		// Waits for all service tasks to complete
+		let join_fut = async {
+			let mut handle_futs = running_services
+				.iter_mut()
+				.map(|(_, handle)| handle)
+				.collect::<FuturesUnordered<_>>();
+
+			while let Some(_) = handle_futs.next().await {}
+		};
+
+		tokio::select! {
+			_ = join_fut => {
+				tracing::info!("all services finished");
+				break;
+			}
+			abort = term_signal.recv() => {
+				shutting_down.store(true, Ordering::SeqCst);
 
-		// Exit
-		tracing::info!("all services finished");
+				// Abort services that don't require graceful shutdown
+				running_services.retain(|(requires_graceful_shutdown, handle)| {
+					if !requires_graceful_shutdown {
+						handle.abort();
+					}
 
-		Ok(())
+					*requires_graceful_shutdown
+				});
+
+				if abort {
+					// Give time for services to handle final abort
+					tokio::time::sleep(Duration::from_millis(50)).await;
+					rivet_runtime::shutdown().await;
+
+					break;
+				}
+			}
+		}
 	}
+
+	// Stops term signal handler bg task
+	rivet_runtime::TermSignal::stop();
+
+	Ok(())
 }
diff --git a/engine/packages/util/src/lib.rs b/engine/packages/util/src/lib.rs
index 213e0be058..d69d627a2a 100644
--- a/engine/packages/util/src/lib.rs
+++ b/engine/packages/util/src/lib.rs
@@ -12,7 +12,6 @@ pub mod geo;
 pub mod math;
 pub mod req;
 pub mod serde;
-pub mod signal;
 pub mod size;
 pub mod sort;
 pub mod timestamp;
diff --git a/engine/packages/util/src/signal.rs b/engine/packages/util/src/signal.rs
deleted file mode 100644
index c4d2edc65d..0000000000
--- a/engine/packages/util/src/signal.rs
+++ /dev/null
@@ -1,49 +0,0 @@
-use anyhow::Result;
-
-#[cfg(unix)]
-use tokio::signal::unix::{Signal, SignalKind, signal};
-
-#[cfg(windows)]
-use tokio::signal::windows::ctrl_c as windows_ctrl_c;
-
-/// Cross-platform termination signal wrapper that handles:
-/// - Unix: SIGTERM and SIGINT
-/// - Windows: Ctrl+C
-pub struct TermSignal {
-	#[cfg(unix)]
-	sigterm: Signal,
-	#[cfg(unix)]
-	sigint: Signal,
-	#[cfg(windows)]
-	ctrl_c: tokio::signal::windows::CtrlC,
-}
-
-impl TermSignal {
-	/// Creates a new termination signal handler
-	pub fn new() -> Result<Self> {
-		Ok(Self {
-			#[cfg(unix)]
-			sigterm: signal(SignalKind::terminate())?,
-			#[cfg(unix)]
-			sigint: signal(SignalKind::interrupt())?,
-			#[cfg(windows)]
-			ctrl_c: windows_ctrl_c()?,
-		})
-	}
-
-	/// Waits for the next termination signal
-	pub async fn recv(&mut self) -> Option<()> {
-		#[cfg(unix)]
-		{
-			tokio::select! {
-				result = self.sigterm.recv() => result,
-				result = self.sigint.recv() => result,
-			}
-		}
-
-		#[cfg(windows)]
-		{
-			self.ctrl_c.recv().await
-		}
-	}
-}
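For downstream crates that build their own service list, the new fourth argument to `Service::new` is the only API change in this patch: `true` opts the service into the graceful-drain window on SIGTERM, `false` lets the service manager abort it immediately. A sketch mirroring the registrations in run_config.rs; the `my_service` module is illustrative:

    Service::new(
        "my_service",
        ServiceKind::Standalone,
        |config, pools| Box::pin(my_service::start(config, pools)),
        // Abort immediately on SIGTERM; nothing in-flight worth draining.
        false,
    )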