Skip to content

Commit 8cef849

Browse files
committed
feat(gateway): websocket hibernation
1 parent fc89662 commit 8cef849

File tree

17 files changed

+476
-140
lines changed

17 files changed

+476
-140
lines changed

CLAUDE.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ Key points:
125125
- Add `Serialize, Deserialize` derives for errors with metadata fields
126126
- Always return anyhow errors from failable functions
127127
- For example: `fn foo() -> Result<i64> { /* ... */ }`
128-
- Import anyhow using `use anyhow::*` instead of importing individual types
128+
- Do not glob import (`::*`) from anyhow. Instead, import individual types and traits
129129

130130
**Dependency Management**
131131
- When adding a dependency, check for a workspace dependency in Cargo.toml

engine/artifacts/errors/guard.websocket_service_hibernate.json

Lines changed: 5 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

engine/artifacts/errors/guard.websocket_service_retry.json

Lines changed: 0 additions & 5 deletions
This file was deleted.

engine/packages/guard-core/src/custom_serve.rs

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use anyhow::*;
1+
use anyhow::{Result, bail};
22
use async_trait::async_trait;
33
use bytes::Bytes;
44
use http_body_util::Full;
@@ -10,6 +10,11 @@ use crate::WebSocketHandle;
1010
use crate::proxy_service::ResponseBody;
1111
use crate::request_context::RequestContext;
1212

13+
pub enum HibernationResult {
14+
Continue,
15+
Close,
16+
}
17+
1318
/// Trait for custom request serving logic that can handle both HTTP and WebSocket requests
1419
#[async_trait]
1520
pub trait CustomServeTrait: Send + Sync {
@@ -23,11 +28,21 @@ pub trait CustomServeTrait: Send + Sync {
2328
/// Handle a WebSocket connection after upgrade. Supports connection retries.
2429
async fn handle_websocket(
2530
&self,
26-
websocket: WebSocketHandle,
27-
headers: &hyper::HeaderMap,
28-
path: &str,
29-
request_context: &mut RequestContext,
31+
_websocket: WebSocketHandle,
32+
_headers: &hyper::HeaderMap,
33+
_path: &str,
34+
_request_context: &mut RequestContext,
3035
// Identifies the websocket across retries.
31-
unique_request_id: Uuid,
32-
) -> Result<Option<CloseFrame>>;
36+
_unique_request_id: Uuid,
37+
) -> Result<Option<CloseFrame>> {
38+
bail!("service does not support websockets");
39+
}
40+
41+
/// Returns true if the websocket should close.
42+
async fn handle_websocket_hibernation(
43+
&self,
44+
_websocket: WebSocketHandle,
45+
) -> Result<HibernationResult> {
46+
bail!("service does not support websocket hibernation");
47+
}
3348
}

engine/packages/guard-core/src/errors.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,12 @@ pub struct ServiceUnavailable;
9393
pub struct WebSocketServiceUnavailable;
9494

9595
#[derive(RivetError, Serialize, Deserialize)]
96-
#[error("guard", "websocket_service_retry", "WebSocket service retry.")]
97-
pub struct WebSocketServiceRetry;
96+
#[error(
97+
"guard",
98+
"websocket_service_hibernate",
99+
"Initiate WebSocket service hibernation."
100+
)]
101+
pub struct WebSocketServiceHibernate;
98102

99103
#[derive(RivetError, Serialize, Deserialize)]
100104
#[error("guard", "websocket_service_timeout", "WebSocket service timed out.")]

engine/packages/guard-core/src/proxy_service.rs

Lines changed: 104 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use anyhow::{Context, Result, bail};
1+
use anyhow::{Context, Result, bail, ensure};
22
use bytes::Bytes;
33
use futures_util::{SinkExt, StreamExt};
44
use http_body_util::{BodyExt, Full};
@@ -31,7 +31,9 @@ use url::Url;
3131
use uuid::Uuid;
3232

3333
use crate::{
34-
WebSocketHandle, custom_serve::CustomServeTrait, errors, metrics,
34+
WebSocketHandle,
35+
custom_serve::{CustomServeTrait, HibernationResult},
36+
errors, metrics,
3537
request_context::RequestContext,
3638
};
3739

@@ -1828,7 +1830,7 @@ impl ProxyService {
18281830
);
18291831
}
18301832
ResolveRouteOutput::Response(_) => unreachable!(),
1831-
ResolveRouteOutput::CustomServe(mut handlers) => {
1833+
ResolveRouteOutput::CustomServe(mut handler) => {
18321834
tracing::debug!(%req_path, "Spawning task to handle WebSocket communication");
18331835
let mut request_context = request_context.clone();
18341836
let req_headers = req_headers.clone();
@@ -1838,14 +1840,15 @@ impl ProxyService {
18381840
tokio::spawn(
18391841
async move {
18401842
let request_id = Uuid::new_v4();
1843+
let mut ws_hibernation_close = false;
18411844
let mut attempts = 0u32;
18421845

18431846
let ws_handle = WebSocketHandle::new(client_ws)
18441847
.await
18451848
.context("failed initiating websocket handle")?;
18461849

18471850
loop {
1848-
match handlers
1851+
match handler
18491852
.handle_websocket(
18501853
ws_handle.clone(),
18511854
&req_headers,
@@ -1895,18 +1898,43 @@ impl ProxyService {
18951898
Err(err) => {
18961899
tracing::debug!(?err, "websocket handler error");
18971900

1898-
// Denotes that the connection did not fail, but needs to be retried to
1899-
// resole a new target
1900-
let ws_retry = is_ws_retry(&err);
1901+
// Denotes that the connection did not fail, but the downstream has closed
1902+
let ws_hibernate = is_ws_hibernate(&err);
19011903

1902-
if ws_retry {
1904+
if ws_hibernate {
19031905
attempts = 0;
19041906
} else {
19051907
attempts += 1;
19061908
}
19071909

1908-
if attempts > max_attempts
1909-
|| (!is_retryable_ws_error(&err) && !ws_retry)
1910+
if ws_hibernate {
1911+
// This should be unreachable because as soon as the actor is
1912+
// reconnected to after hibernation the gateway will consume the close
1913+
// frame from the client ws stream
1914+
ensure!(
1915+
!ws_hibernation_close,
1916+
"should not be hibernating again after receiving a close frame during hibernation"
1917+
);
1918+
1919+
// After this function returns:
1920+
// - the route will be resolved again
1921+
// - the websocket will connect to the new downstream target
1922+
// - the gateway will continue reading messages from the client ws
1923+
// (starting with the message that caused the hibernation to end)
1924+
let res = handler
1925+
.handle_websocket_hibernation(ws_handle.clone())
1926+
.await?;
1927+
1928+
// Despite receiving a close frame from the client during hibernation
1929+
// we are going to reconnect to the actor so that it knows the
1930+
// connection has closed
1931+
if let HibernationResult::Close = res {
1932+
tracing::debug!("starting hibernating websocket close");
1933+
1934+
ws_hibernation_close = true;
1935+
}
1936+
} else if attempts > max_attempts
1937+
|| !is_retryable_ws_error(&err)
19101938
{
19111939
tracing::debug!(
19121940
?attempts,
@@ -1929,79 +1957,79 @@ impl ProxyService {
19291957

19301958
break;
19311959
} else {
1932-
if !ws_retry {
1933-
let backoff = ProxyService::calculate_backoff(
1934-
attempts,
1935-
initial_interval,
1936-
);
1960+
let backoff = ProxyService::calculate_backoff(
1961+
attempts,
1962+
initial_interval,
1963+
);
19371964

1938-
tracing::debug!(
1939-
?backoff,
1940-
"WebSocket attempt {attempts} failed (service unavailable)"
1941-
);
1965+
tracing::debug!(
1966+
?backoff,
1967+
"WebSocket attempt {attempts} failed (service unavailable)"
1968+
);
19421969

1943-
tokio::time::sleep(backoff).await;
1944-
}
1970+
// Apply backoff for retryable error
1971+
tokio::time::sleep(backoff).await;
1972+
}
19451973

1946-
match state
1947-
.resolve_route(
1948-
&req_host,
1949-
&req_path,
1950-
&req_method,
1951-
state.port_type.clone(),
1952-
&req_headers,
1953-
true,
1954-
)
1955-
.await
1956-
{
1957-
Ok(ResolveRouteOutput::CustomServe(new_handlers)) => {
1958-
handlers = new_handlers;
1959-
continue;
1960-
}
1961-
Ok(ResolveRouteOutput::Response(response)) => {
1962-
ws_handle
1963-
.send(to_hyper_close(Some(str_to_close_frame(
1964-
response.message.as_ref(),
1965-
))))
1966-
.await?;
1967-
1968-
// Flush to ensure close frame is sent
1969-
ws_handle.flush().await?;
1970-
1971-
// Keep TCP connection open briefly to allow client to process close
1972-
tokio::time::sleep(WEBSOCKET_CLOSE_LINGER).await;
1973-
}
1974-
Ok(ResolveRouteOutput::Target(_)) => {
1975-
ws_handle
1976-
.send(to_hyper_close(Some(err_to_close_frame(
1977-
errors::WebSocketTargetChanged.build(),
1978-
ray_id,
1979-
))))
1980-
.await?;
1974+
// Retry route resolution
1975+
match state
1976+
.resolve_route(
1977+
&req_host,
1978+
&req_path,
1979+
&req_method,
1980+
state.port_type.clone(),
1981+
&req_headers,
1982+
true,
1983+
)
1984+
.await
1985+
{
1986+
Ok(ResolveRouteOutput::CustomServe(new_handler)) => {
1987+
handler = new_handler;
1988+
continue;
1989+
}
1990+
Ok(ResolveRouteOutput::Response(response)) => {
1991+
ws_handle
1992+
.send(to_hyper_close(Some(str_to_close_frame(
1993+
response.message.as_ref(),
1994+
))))
1995+
.await?;
1996+
1997+
// Flush to ensure close frame is sent
1998+
ws_handle.flush().await?;
1999+
2000+
// Keep TCP connection open briefly to allow client to process close
2001+
tokio::time::sleep(WEBSOCKET_CLOSE_LINGER).await;
2002+
}
2003+
Ok(ResolveRouteOutput::Target(_)) => {
2004+
ws_handle
2005+
.send(to_hyper_close(Some(err_to_close_frame(
2006+
errors::WebSocketTargetChanged.build(),
2007+
ray_id,
2008+
))))
2009+
.await?;
19812010

1982-
// Flush to ensure close frame is sent
1983-
ws_handle.flush().await?;
2011+
// Flush to ensure close frame is sent
2012+
ws_handle.flush().await?;
19842013

1985-
// Keep TCP connection open briefly to allow client to process close
1986-
tokio::time::sleep(WEBSOCKET_CLOSE_LINGER).await;
2014+
// Keep TCP connection open briefly to allow client to process close
2015+
tokio::time::sleep(WEBSOCKET_CLOSE_LINGER).await;
19872016

1988-
break;
1989-
}
1990-
Err(err) => {
1991-
ws_handle
1992-
.send(to_hyper_close(Some(err_to_close_frame(
1993-
err, ray_id,
1994-
))))
1995-
.await?;
2017+
break;
2018+
}
2019+
Err(err) => {
2020+
ws_handle
2021+
.send(to_hyper_close(Some(err_to_close_frame(
2022+
err, ray_id,
2023+
))))
2024+
.await?;
19962025

1997-
// Flush to ensure close frame is sent
1998-
ws_handle.flush().await?;
2026+
// Flush to ensure close frame is sent
2027+
ws_handle.flush().await?;
19992028

2000-
// Keep TCP connection open briefly to allow client to process close
2001-
tokio::time::sleep(WEBSOCKET_CLOSE_LINGER).await;
2029+
// Keep TCP connection open briefly to allow client to process close
2030+
tokio::time::sleep(WEBSOCKET_CLOSE_LINGER).await;
20022031

2003-
break;
2004-
}
2032+
break;
20052033
}
20062034
}
20072035
}
@@ -2509,9 +2537,9 @@ fn is_retryable_ws_error(err: &anyhow::Error) -> bool {
25092537
}
25102538
}
25112539

2512-
fn is_ws_retry(err: &anyhow::Error) -> bool {
2540+
fn is_ws_hibernate(err: &anyhow::Error) -> bool {
25132541
if let Some(rivet_err) = err.chain().find_map(|x| x.downcast_ref::<RivetError>()) {
2514-
rivet_err.group() == "guard" && rivet_err.code() == "websocket_service_retry"
2542+
rivet_err.group() == "guard" && rivet_err.code() == "websocket_service_hibernate"
25152543
} else {
25162544
false
25172545
}

engine/packages/guard-core/src/websocket_handle.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use anyhow::*;
2-
use futures_util::{SinkExt, StreamExt};
2+
use futures_util::{SinkExt, StreamExt, stream::Peekable};
33
use hyper::upgrade::Upgraded;
44
use hyper_tungstenite::HyperWebsocket;
55
use hyper_tungstenite::tungstenite::Message as WsMessage;
@@ -8,7 +8,8 @@ use std::sync::Arc;
88
use tokio::sync::Mutex;
99
use tokio_tungstenite::WebSocketStream;
1010

11-
pub type WebSocketReceiver = futures_util::stream::SplitStream<WebSocketStream<TokioIo<Upgraded>>>;
11+
pub type WebSocketReceiver =
12+
Peekable<futures_util::stream::SplitStream<WebSocketStream<TokioIo<Upgraded>>>>;
1213

1314
pub type WebSocketSender =
1415
futures_util::stream::SplitSink<WebSocketStream<TokioIo<Upgraded>>, WsMessage>;
@@ -26,7 +27,7 @@ impl WebSocketHandle {
2627

2728
Ok(Self {
2829
ws_tx: Arc::new(Mutex::new(ws_tx)),
29-
ws_rx: Arc::new(Mutex::new(ws_rx)),
30+
ws_rx: Arc::new(Mutex::new(ws_rx.peekable())),
3031
})
3132
}
3233

0 commit comments

Comments
 (0)