Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions engine/packages/config/src/config/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ pub struct Runtime {
/// Defaults to 30 seconds.
worker_shutdown_duration: Option<u32>,
/// Time (in seconds) to allow for guard to wait for pending requests after receiving SIGTERM. Defaults
// to 1 hour.
/// to 1 hour.
guard_shutdown_duration: Option<u32>,
/// Whether or not to allow running the engine when the previous version that was run is higher than
// the current version.
/// the current version.
allow_version_rollback: Option<bool>,
}

Expand Down
2 changes: 1 addition & 1 deletion engine/packages/gasoline/src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ impl Worker {
if remaining_workflows == 0 {
tracing::info!("all workflows evicted");
} else {
tracing::warn!(remaining_workflows=?self.running_workflows.len(), "not all workflows evicted");
tracing::warn!(?remaining_workflows, "not all workflows evicted");
}

tracing::info!("worker shutdown complete");
Expand Down
1 change: 1 addition & 0 deletions engine/packages/guard-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pub mod metrics;
pub mod proxy_service;
pub mod request_context;
mod server;
mod task_group;
pub mod types;
pub mod websocket_handle;

Expand Down
57 changes: 25 additions & 32 deletions engine/packages/guard-core/src/proxy_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ use crate::{
custom_serve::{CustomServeTrait, HibernationResult},
errors, metrics,
request_context::RequestContext,
task_group::TaskGroup,
};

const X_RIVET_TARGET: HeaderName = HeaderName::from_static("x-rivet-target");
Expand Down Expand Up @@ -350,6 +351,7 @@ pub struct ProxyState {
in_flight_counters: Cache<(Id, std::net::IpAddr), Arc<Mutex<InFlightCounter>>>,
port_type: PortType,
clickhouse_inserter: Option<clickhouse_inserter::ClickHouseInserterHandle>,
tasks: Arc<TaskGroup>,
}

impl ProxyState {
Expand Down Expand Up @@ -377,6 +379,7 @@ impl ProxyState {
.build(),
port_type,
clickhouse_inserter,
tasks: TaskGroup::new(),
}
}

Expand Down Expand Up @@ -782,14 +785,6 @@ impl ProxyService {
metrics::PROXY_REQUEST_PENDING.add(1, &[]);
metrics::PROXY_REQUEST_TOTAL.add(1, &[]);

// Prepare to release in-flight counter when done
let state_clone = self.state.clone();
crate::defer! {
tokio::spawn(async move {
state_clone.release_in_flight(client_ip, &actor_id).await;
}.instrument(tracing::info_span!("release_in_flight_task")));
}

// Update request context with target info
if let Some(actor_id) = actor_id {
request_context.service_actor_id = Some(actor_id);
Expand All @@ -814,6 +809,15 @@ impl ProxyService {

metrics::PROXY_REQUEST_PENDING.add(-1, &[]);

// Release in-flight counter when done
let state_clone = self.state.clone();
tokio::spawn(
async move {
state_clone.release_in_flight(client_ip, &actor_id).await;
}
.instrument(tracing::info_span!("release_in_flight_task")),
);

res
}

Expand Down Expand Up @@ -1254,7 +1258,7 @@ impl ProxyService {
match target {
ResolveRouteOutput::Target(mut target) => {
tracing::debug!("Spawning task to handle WebSocket communication");
tokio::spawn(
self.state.tasks.spawn(
async move {
// Set up a timeout for the entire operation
let timeout_duration = Duration::from_secs(30); // 30 seconds timeout
Expand Down Expand Up @@ -1837,7 +1841,7 @@ impl ProxyService {
let req_path = req_path.clone();
let req_host = req_host.clone();

tokio::spawn(
self.state.tasks.spawn(
async move {
let request_id = Uuid::new_v4();
let mut ws_hibernation_close = false;
Expand Down Expand Up @@ -2194,7 +2198,7 @@ impl ProxyService {
Ok((client_response, client_ws)) => {
tracing::debug!("Client WebSocket upgrade for error proxy successful");

tokio::spawn(
self.state.tasks.spawn(
async move {
let ws_handle = match WebSocketHandle::new(client_ws).await {
Ok(ws_handle) => ws_handle,
Expand Down Expand Up @@ -2337,11 +2341,14 @@ impl ProxyService {

// Insert analytics event asynchronously
let mut context_clone = request_context.clone();
tokio::spawn(async move {
if let Err(error) = context_clone.insert_event().await {
tracing::warn!(?error, "failed to insert guard analytics event");
tokio::spawn(
async move {
if let Err(error) = context_clone.insert_event().await {
tracing::warn!(?error, "failed to insert guard analytics event");
}
}
});
.instrument(tracing::info_span!("insert_event_task")),
);

let content_length = res
.headers()
Expand Down Expand Up @@ -2407,24 +2414,10 @@ impl ProxyServiceFactory {
/// Builds a [`ProxyService`] for the connection coming from `remote_addr`.
///
/// The new service receives its own clone of this factory's state.
pub fn create_service(&self, remote_addr: SocketAddr) -> ProxyService {
    let state = self.state.clone();
    ProxyService::new(state, remote_addr)
}
}

// Helper macro for defer-like functionality
#[macro_export]
macro_rules! defer {
($($body:tt)*) => {
let _guard = {
struct Guard<F: FnOnce()>(Option<F>);
impl<F: FnOnce()> Drop for Guard<F> {
fn drop(&mut self) {
if let Some(f) = self.0.take() {
f()
}
}
}
Guard(Some(|| { $($body)* }))
};
};
/// Resolves once the task group stored in this factory's state reports that it is idle.
///
/// This simply forwards to the `wait_idle` of `state.tasks`.
pub async fn wait_idle(&self) {
    let tasks = &self.state.tasks;
    tasks.wait_idle().await
}
}

fn add_proxy_headers_with_addr(
Expand Down
26 changes: 18 additions & 8 deletions engine/packages/guard-core/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@ use std::{
time::{Duration, Instant},
};

use crate::cert_resolver::{CertResolverFn, create_tls_config};
use crate::metrics;
use crate::proxy_service::{CacheKeyFn, MiddlewareFn, ProxyServiceFactory, RoutingFn};
use anyhow::Result;
use futures_util::FutureExt;
use hyper::service::service_fn;
use rivet_runtime::TermSignal;
use tokio_rustls::TlsAcceptor;
use tracing::Instrument;

use crate::cert_resolver::{CertResolverFn, create_tls_config};
use crate::metrics;
use crate::proxy_service::{CacheKeyFn, MiddlewareFn, ProxyServiceFactory, RoutingFn};

// Start the server
#[tracing::instrument(skip_all)]
pub async fn run_server(
Expand Down Expand Up @@ -72,11 +73,8 @@ pub async fn run_server(
(None, None, None, None)
};

// Set up server builder and graceful shutdown
let server = hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new());
let graceful = hyper_util::server::graceful::GracefulShutdown::new();

// Set up signal handling for graceful shutdown
let mut term_signal = TermSignal::new().await;

tracing::info!("HTTP server listening on {}", http_addr);
Expand Down Expand Up @@ -252,11 +250,23 @@ pub async fn run_server(
let shutdown_duration = config.runtime.guard_shutdown_duration();
tracing::info!(duration=?shutdown_duration, "starting guard shutdown");

let mut graceful_fut = async move { graceful.shutdown().await }.boxed();
let mut complete_fut = async move {
// Wait until remaining requests finish
graceful.shutdown().await;

// Wait until remaining tasks finish
http_factory.wait_idle().await;

if let Some(https_factory) = https_factory {
https_factory.wait_idle().await;
}
}
.boxed();

let shutdown_start = Instant::now();
loop {
tokio::select! {
_ = &mut graceful_fut => {
_ = &mut complete_fut => {
tracing::info!("all guard requests completed");
break;
}
Expand Down
57 changes: 57 additions & 0 deletions engine/packages/guard-core/src/task_group.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};

use futures::Future;
use tokio::sync::Notify;
use tracing::Instrument;

/// Tracks a set of spawned background tasks so a caller can wait for all of them to end.
///
/// Every task started through [`TaskGroup::spawn`] is counted for as long as it is alive, and
/// [`TaskGroup::wait_idle`] resolves once that count is zero.
pub struct TaskGroup {
    /// Number of tasks spawned through this group that have not ended yet.
    running_count: AtomicUsize,
    /// Signalled every time `running_count` drops to zero.
    notify: Notify,
}

impl TaskGroup {
    /// Creates an empty group.
    ///
    /// The group is returned inside an `Arc` because `spawn` hands a clone of it to each task.
    pub fn new() -> Arc<Self> {
        Arc::new(Self {
            running_count: AtomicUsize::new(0),
            notify: Notify::new(),
        })
    }

    /// Spawns `fut` on the Tokio runtime and counts it as running until it ends.
    ///
    /// The output of `fut` is discarded. The task inherits the caller's current tracing span.
    ///
    /// The running count is released from a drop guard, so it is decremented no matter how the
    /// task ends: normal completion, a panic inside `fut`, or the task being dropped before it
    /// finishes. Without this, a single panicking task would keep `wait_idle` pending forever.
    pub fn spawn<F, O>(self: &Arc<Self>, fut: F)
    where
        F: Future<Output = O> + Send + 'static,
    {
        // Releases one unit of `running_count` when dropped and wakes waiters if it was the
        // last one.
        struct RunningGuard(Arc<TaskGroup>);

        impl Drop for RunningGuard {
            fn drop(&mut self) {
                // `fetch_sub` returns the previous value, so `1` means the count is now zero.
                if self.0.running_count.fetch_sub(1, Ordering::AcqRel) == 1 {
                    self.0.notify.notify_waiters();
                }
            }
        }

        // Count the task before it is handed to the runtime so `wait_idle` can never observe
        // zero while this task is still pending.
        self.running_count.fetch_add(1, Ordering::Relaxed);
        let guard = RunningGuard(Arc::clone(self));

        tokio::spawn(
            async move {
                // Held for the whole task; dropped on completion, unwind, or cancellation.
                let _guard = guard;
                fut.await;
            }
            .in_current_span(),
        );
    }

    /// Waits until no task spawned through this group is running.
    ///
    /// Returns immediately when the group is already idle. Tasks spawned after this returns are
    /// not waited for.
    #[tracing::instrument(skip_all)]
    pub async fn wait_idle(&self) {
        loop {
            // Create the `Notified` future *before* reading the count. `notify_waiters` stores
            // no permit, so a wakeup sent between the load and a later `notified()` call would
            // be lost; a future that already exists is guaranteed to observe it.
            let notified = self.notify.notified();

            if self.running_count.load(Ordering::Acquire) == 0 {
                return;
            }

            notified.await;
        }
    }
}
2 changes: 1 addition & 1 deletion engine/packages/service-manager/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ rivet-pools.workspace = true
rivet-runtime.workspace = true
tokio-cron-scheduler.workspace = true
tokio.workspace = true
tracing.workspace = true
tracing.workspace = true
4 changes: 2 additions & 2 deletions engine/packages/service-manager/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,9 @@ pub async fn start(

loop {
match (service.run)(config.clone(), pools.clone()).await {
Result::Ok(res) => {
Result::Ok(_) => {
if shutting_down.load(Ordering::SeqCst) {
tracing::info!(service=%service.name, ?res, "service exited");
tracing::info!(service=%service.name, "service exited");
break;
} else {
tracing::error!(service=%service.name, "service exited unexpectedly");
Expand Down
Loading
Loading