Skip to content

Commit bc86aa3

Browse files
committed
fix(guard): handle websocket tasks during shutdown
1 parent 7aff577 commit bc86aa3

File tree

10 files changed

+578
-517
lines changed

10 files changed

+578
-517
lines changed

engine/artifacts/openapi.json

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

engine/packages/config/src/config/runtime.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@ pub struct Runtime {
99
/// Defaults to 30 seconds.
1010
worker_shutdown_duration: Option<u32>,
1111
/// Time (in seconds) to allow for guard to wait for pending requests after receiving SIGTERM. Defaults
12-
// to 1 hour.
12+
/// to 1 hour.
1313
guard_shutdown_duration: Option<u32>,
1414
/// Whether or not to allow running the engine when the previous version that was run is higher than
15-
// the current version.
15+
/// the current version.
1616
allow_version_rollback: Option<bool>,
1717
}
1818

engine/packages/gasoline/src/worker.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ impl Worker {
323323
if remaining_workflows == 0 {
324324
tracing::info!("all workflows evicted");
325325
} else {
326-
tracing::warn!(remaining_workflows=?self.running_workflows.len(), "not all workflows evicted");
326+
tracing::warn!(?remaining_workflows, "not all workflows evicted");
327327
}
328328

329329
tracing::info!("worker shutdown complete");

engine/packages/guard-core/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ pub mod metrics;
66
pub mod proxy_service;
77
pub mod request_context;
88
mod server;
9+
mod task_group;
910
pub mod types;
1011
pub mod websocket_handle;
1112

engine/packages/guard-core/src/proxy_service.rs

Lines changed: 25 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ use crate::{
3535
custom_serve::{CustomServeTrait, HibernationResult},
3636
errors, metrics,
3737
request_context::RequestContext,
38+
task_group::TaskGroup,
3839
};
3940

4041
const X_RIVET_TARGET: HeaderName = HeaderName::from_static("x-rivet-target");
@@ -350,6 +351,7 @@ pub struct ProxyState {
350351
in_flight_counters: Cache<(Id, std::net::IpAddr), Arc<Mutex<InFlightCounter>>>,
351352
port_type: PortType,
352353
clickhouse_inserter: Option<clickhouse_inserter::ClickHouseInserterHandle>,
354+
tasks: Arc<TaskGroup>,
353355
}
354356

355357
impl ProxyState {
@@ -377,6 +379,7 @@ impl ProxyState {
377379
.build(),
378380
port_type,
379381
clickhouse_inserter,
382+
tasks: TaskGroup::new(),
380383
}
381384
}
382385

@@ -782,14 +785,6 @@ impl ProxyService {
782785
metrics::PROXY_REQUEST_PENDING.add(1, &[]);
783786
metrics::PROXY_REQUEST_TOTAL.add(1, &[]);
784787

785-
// Prepare to release in-flight counter when done
786-
let state_clone = self.state.clone();
787-
crate::defer! {
788-
tokio::spawn(async move {
789-
state_clone.release_in_flight(client_ip, &actor_id).await;
790-
}.instrument(tracing::info_span!("release_in_flight_task")));
791-
}
792-
793788
// Update request context with target info
794789
if let Some(actor_id) = actor_id {
795790
request_context.service_actor_id = Some(actor_id);
@@ -814,6 +809,15 @@ impl ProxyService {
814809

815810
metrics::PROXY_REQUEST_PENDING.add(-1, &[]);
816811

812+
// Release in-flight counter when done
813+
let state_clone = self.state.clone();
814+
tokio::spawn(
815+
async move {
816+
state_clone.release_in_flight(client_ip, &actor_id).await;
817+
}
818+
.instrument(tracing::info_span!("release_in_flight_task")),
819+
);
820+
817821
res
818822
}
819823

@@ -1254,7 +1258,7 @@ impl ProxyService {
12541258
match target {
12551259
ResolveRouteOutput::Target(mut target) => {
12561260
tracing::debug!("Spawning task to handle WebSocket communication");
1257-
tokio::spawn(
1261+
self.state.tasks.spawn(
12581262
async move {
12591263
// Set up a timeout for the entire operation
12601264
let timeout_duration = Duration::from_secs(30); // 30 seconds timeout
@@ -1837,7 +1841,7 @@ impl ProxyService {
18371841
let req_path = req_path.clone();
18381842
let req_host = req_host.clone();
18391843

1840-
tokio::spawn(
1844+
self.state.tasks.spawn(
18411845
async move {
18421846
let request_id = Uuid::new_v4();
18431847
let mut ws_hibernation_close = false;
@@ -2194,7 +2198,7 @@ impl ProxyService {
21942198
Ok((client_response, client_ws)) => {
21952199
tracing::debug!("Client WebSocket upgrade for error proxy successful");
21962200

2197-
tokio::spawn(
2201+
self.state.tasks.spawn(
21982202
async move {
21992203
let ws_handle = match WebSocketHandle::new(client_ws).await {
22002204
Ok(ws_handle) => ws_handle,
@@ -2337,11 +2341,14 @@ impl ProxyService {
23372341

23382342
// Insert analytics event asynchronously
23392343
let mut context_clone = request_context.clone();
2340-
tokio::spawn(async move {
2341-
if let Err(error) = context_clone.insert_event().await {
2342-
tracing::warn!(?error, "failed to insert guard analytics event");
2344+
tokio::spawn(
2345+
async move {
2346+
if let Err(error) = context_clone.insert_event().await {
2347+
tracing::warn!(?error, "failed to insert guard analytics event");
2348+
}
23432349
}
2344-
});
2350+
.instrument(tracing::info_span!("insert_event_task")),
2351+
);
23452352

23462353
let content_length = res
23472354
.headers()
@@ -2407,24 +2414,10 @@ impl ProxyServiceFactory {
24072414
pub fn create_service(&self, remote_addr: SocketAddr) -> ProxyService {
24082415
ProxyService::new(self.state.clone(), remote_addr)
24092416
}
2410-
}
24112417

2412-
// Helper macro for defer-like functionality
2413-
#[macro_export]
2414-
macro_rules! defer {
2415-
($($body:tt)*) => {
2416-
let _guard = {
2417-
struct Guard<F: FnOnce()>(Option<F>);
2418-
impl<F: FnOnce()> Drop for Guard<F> {
2419-
fn drop(&mut self) {
2420-
if let Some(f) = self.0.take() {
2421-
f()
2422-
}
2423-
}
2424-
}
2425-
Guard(Some(|| { $($body)* }))
2426-
};
2427-
};
2418+
pub async fn wait_idle(&self) {
2419+
self.state.tasks.wait_idle().await
2420+
}
24282421
}
24292422

24302423
fn add_proxy_headers_with_addr(

engine/packages/guard-core/src/server.rs

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,17 @@ use std::{
44
time::{Duration, Instant},
55
};
66

7-
use crate::cert_resolver::{CertResolverFn, create_tls_config};
8-
use crate::metrics;
9-
use crate::proxy_service::{CacheKeyFn, MiddlewareFn, ProxyServiceFactory, RoutingFn};
107
use anyhow::Result;
118
use futures_util::FutureExt;
129
use hyper::service::service_fn;
1310
use rivet_runtime::TermSignal;
1411
use tokio_rustls::TlsAcceptor;
1512
use tracing::Instrument;
1613

14+
use crate::cert_resolver::{CertResolverFn, create_tls_config};
15+
use crate::metrics;
16+
use crate::proxy_service::{CacheKeyFn, MiddlewareFn, ProxyServiceFactory, RoutingFn};
17+
1718
// Start the server
1819
#[tracing::instrument(skip_all)]
1920
pub async fn run_server(
@@ -72,11 +73,8 @@ pub async fn run_server(
7273
(None, None, None, None)
7374
};
7475

75-
// Set up server builder and graceful shutdown
7676
let server = hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new());
7777
let graceful = hyper_util::server::graceful::GracefulShutdown::new();
78-
79-
// Set up signal handling for graceful shutdown
8078
let mut term_signal = TermSignal::new().await;
8179

8280
tracing::info!("HTTP server listening on {}", http_addr);
@@ -252,11 +250,23 @@ pub async fn run_server(
252250
let shutdown_duration = config.runtime.guard_shutdown_duration();
253251
tracing::info!(duration=?shutdown_duration, "starting guard shutdown");
254252

255-
let mut graceful_fut = async move { graceful.shutdown().await }.boxed();
253+
let mut complete_fut = async move {
254+
// Wait until remaining requests finish
255+
graceful.shutdown().await;
256+
257+
// Wait until remaining tasks finish
258+
http_factory.wait_idle().await;
259+
260+
if let Some(https_factory) = https_factory {
261+
https_factory.wait_idle().await;
262+
}
263+
}
264+
.boxed();
265+
256266
let shutdown_start = Instant::now();
257267
loop {
258268
tokio::select! {
259-
_ = &mut graceful_fut => {
269+
_ = &mut complete_fut => {
260270
tracing::info!("all guard requests completed");
261271
break;
262272
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
use std::sync::Arc;
2+
use std::sync::atomic::{AtomicUsize, Ordering};
3+
4+
use futures::Future;
5+
use tokio::sync::Notify;
6+
use tracing::Instrument;
7+
8+
pub struct TaskGroup {
9+
running_count: AtomicUsize,
10+
notify: Notify,
11+
}
12+
13+
impl TaskGroup {
14+
pub fn new() -> Arc<Self> {
15+
Arc::new(Self {
16+
running_count: AtomicUsize::new(0),
17+
notify: Notify::new(),
18+
})
19+
}
20+
21+
pub fn spawn<F, O>(self: &Arc<Self>, fut: F)
22+
where
23+
F: Future<Output = O> + Send + 'static,
24+
{
25+
self.running_count.fetch_add(1, Ordering::Relaxed);
26+
27+
let self2 = self.clone();
28+
tokio::spawn(
29+
async move {
30+
fut.await;
31+
32+
// Decrement and notify any waiters if the count hits zero
33+
if self2.running_count.fetch_sub(1, Ordering::AcqRel) == 1 {
34+
self2.notify.notify_waiters();
35+
}
36+
}
37+
.in_current_span(),
38+
);
39+
}
40+
41+
#[tracing::instrument(skip_all)]
42+
pub async fn wait_idle(&self) {
43+
// Fast path
44+
if self.running_count.load(Ordering::Acquire) == 0 {
45+
return;
46+
}
47+
48+
// Wait for notifications until the count reaches zero
49+
loop {
50+
self.notify.notified().await;
51+
if self.running_count.load(Ordering::Acquire) == 0 {
52+
break;
53+
}
54+
}
55+
}
56+
}

engine/packages/service-manager/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@ rivet-pools.workspace = true
1616
rivet-runtime.workspace = true
1717
tokio-cron-scheduler.workspace = true
1818
tokio.workspace = true
19-
tracing.workspace = true
19+
tracing.workspace = true

engine/packages/service-manager/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,9 @@ pub async fn start(
160160

161161
loop {
162162
match (service.run)(config.clone(), pools.clone()).await {
163-
Result::Ok(res) => {
163+
Result::Ok(_) => {
164164
if shutting_down.load(Ordering::SeqCst) {
165-
tracing::info!(service=%service.name, ?res, "service exited");
165+
tracing::info!(service=%service.name, "service exited");
166166
break;
167167
} else {
168168
tracing::error!(service=%service.name, "service exited unexpectedly");

0 commit comments

Comments
 (0)