Skip to content

Commit 3a226de

Browse files
committed
feat(gateway): websocket hibernation
1 parent fc89662 commit 3a226de

File tree

19 files changed

+354
-116
lines changed

19 files changed

+354
-116
lines changed

CLAUDE.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ Key points:
125125
- Add `Serialize, Deserialize` derives for errors with metadata fields
126126
- Always return anyhow errors from failable functions
127127
- For example: `fn foo() -> Result<i64> { /* ... */ }`
128-
- Import anyhow using `use anyhow::*` instead of importing individual types
128+
- Do not glob import (`::*`) from anyhow. Instead, import individual types and traits
129129

130130
**Dependency Management**
131131
- When adding a dependency, check for a workspace dependency in Cargo.toml

engine/artifacts/errors/guard.websocket_service_hibernate.json

Lines changed: 5 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

engine/artifacts/errors/guard.websocket_service_retry.json

Lines changed: 0 additions & 5 deletions
This file was deleted.

engine/packages/guard-core/src/custom_serve.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,6 @@ pub trait CustomServeTrait: Send + Sync {
3030
// Identifies the websocket across retries.
3131
unique_request_id: Uuid,
3232
) -> Result<Option<CloseFrame>>;
33+
34+
async fn handle_websocket_hibernation(&self, websocket: WebSocketHandle) -> Result<()>;
3335
}

engine/packages/guard-core/src/errors.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,12 @@ pub struct ServiceUnavailable;
9393
pub struct WebSocketServiceUnavailable;
9494

9595
#[derive(RivetError, Serialize, Deserialize)]
96-
#[error("guard", "websocket_service_retry", "WebSocket service retry.")]
97-
pub struct WebSocketServiceRetry;
96+
#[error(
97+
"guard",
98+
"websocket_service_hibernate",
99+
"Initiate WebSocket service hibernation."
100+
)]
101+
pub struct WebSocketServiceHibernate;
98102

99103
#[derive(RivetError, Serialize, Deserialize)]
100104
#[error("guard", "websocket_service_timeout", "WebSocket service timed out.")]

engine/packages/guard-core/src/proxy_service.rs

Lines changed: 82 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1828,7 +1828,7 @@ impl ProxyService {
18281828
);
18291829
}
18301830
ResolveRouteOutput::Response(_) => unreachable!(),
1831-
ResolveRouteOutput::CustomServe(mut handlers) => {
1831+
ResolveRouteOutput::CustomServe(mut handler) => {
18321832
tracing::debug!(%req_path, "Spawning task to handle WebSocket communication");
18331833
let mut request_context = request_context.clone();
18341834
let req_headers = req_headers.clone();
@@ -1845,7 +1845,7 @@ impl ProxyService {
18451845
.context("failed initiating websocket handle")?;
18461846

18471847
loop {
1848-
match handlers
1848+
match handler
18491849
.handle_websocket(
18501850
ws_handle.clone(),
18511851
&req_headers,
@@ -1895,18 +1895,26 @@ impl ProxyService {
18951895
Err(err) => {
18961896
tracing::debug!(?err, "websocket handler error");
18971897

1898-
// Denotes that the connection did not fail, but needs to be retried to
1899-
// resolve a new target
1900-
let ws_retry = is_ws_retry(&err);
1898+
// Denotes that the connection did not fail, but the downstream has closed
1899+
let ws_hibernate = is_ws_hibernate(&err);
19011900

1902-
if ws_retry {
1901+
if ws_hibernate {
19031902
attempts = 0;
19041903
} else {
19051904
attempts += 1;
19061905
}
19071906

1908-
if attempts > max_attempts
1909-
|| (!is_retryable_ws_error(&err) && !ws_retry)
1907+
if ws_hibernate {
1908+
// After this function returns:
1909+
// - the route will be resolved again
1910+
// - the websocket will connect to the new downstream target
1911+
// - the gateway will continue reading messages from the client ws
1912+
// (starting with the message that caused the hibernation to end)
1913+
handler
1914+
.handle_websocket_hibernation(ws_handle.clone())
1915+
.await?;
1916+
} else if attempts > max_attempts
1917+
|| !is_retryable_ws_error(&err)
19101918
{
19111919
tracing::debug!(
19121920
?attempts,
@@ -1929,79 +1937,79 @@ impl ProxyService {
19291937

19301938
break;
19311939
} else {
1932-
if !ws_retry {
1933-
let backoff = ProxyService::calculate_backoff(
1934-
attempts,
1935-
initial_interval,
1936-
);
1940+
let backoff = ProxyService::calculate_backoff(
1941+
attempts,
1942+
initial_interval,
1943+
);
19371944

1938-
tracing::debug!(
1939-
?backoff,
1940-
"WebSocket attempt {attempts} failed (service unavailable)"
1941-
);
1945+
tracing::debug!(
1946+
?backoff,
1947+
"WebSocket attempt {attempts} failed (service unavailable)"
1948+
);
19421949

1943-
tokio::time::sleep(backoff).await;
1944-
}
1950+
// Apply backoff for retryable error
1951+
tokio::time::sleep(backoff).await;
1952+
}
19451953

1946-
match state
1947-
.resolve_route(
1948-
&req_host,
1949-
&req_path,
1950-
&req_method,
1951-
state.port_type.clone(),
1952-
&req_headers,
1953-
true,
1954-
)
1955-
.await
1956-
{
1957-
Ok(ResolveRouteOutput::CustomServe(new_handlers)) => {
1958-
handlers = new_handlers;
1959-
continue;
1960-
}
1961-
Ok(ResolveRouteOutput::Response(response)) => {
1962-
ws_handle
1963-
.send(to_hyper_close(Some(str_to_close_frame(
1964-
response.message.as_ref(),
1965-
))))
1966-
.await?;
1967-
1968-
// Flush to ensure close frame is sent
1969-
ws_handle.flush().await?;
1970-
1971-
// Keep TCP connection open briefly to allow client to process close
1972-
tokio::time::sleep(WEBSOCKET_CLOSE_LINGER).await;
1973-
}
1974-
Ok(ResolveRouteOutput::Target(_)) => {
1975-
ws_handle
1976-
.send(to_hyper_close(Some(err_to_close_frame(
1977-
errors::WebSocketTargetChanged.build(),
1978-
ray_id,
1979-
))))
1980-
.await?;
1954+
// Retry route resolution
1955+
match state
1956+
.resolve_route(
1957+
&req_host,
1958+
&req_path,
1959+
&req_method,
1960+
state.port_type.clone(),
1961+
&req_headers,
1962+
true,
1963+
)
1964+
.await
1965+
{
1966+
Ok(ResolveRouteOutput::CustomServe(new_handler)) => {
1967+
handler = new_handler;
1968+
continue;
1969+
}
1970+
Ok(ResolveRouteOutput::Response(response)) => {
1971+
ws_handle
1972+
.send(to_hyper_close(Some(str_to_close_frame(
1973+
response.message.as_ref(),
1974+
))))
1975+
.await?;
1976+
1977+
// Flush to ensure close frame is sent
1978+
ws_handle.flush().await?;
1979+
1980+
// Keep TCP connection open briefly to allow client to process close
1981+
tokio::time::sleep(WEBSOCKET_CLOSE_LINGER).await;
1982+
}
1983+
Ok(ResolveRouteOutput::Target(_)) => {
1984+
ws_handle
1985+
.send(to_hyper_close(Some(err_to_close_frame(
1986+
errors::WebSocketTargetChanged.build(),
1987+
ray_id,
1988+
))))
1989+
.await?;
19811990

1982-
// Flush to ensure close frame is sent
1983-
ws_handle.flush().await?;
1991+
// Flush to ensure close frame is sent
1992+
ws_handle.flush().await?;
19841993

1985-
// Keep TCP connection open briefly to allow client to process close
1986-
tokio::time::sleep(WEBSOCKET_CLOSE_LINGER).await;
1994+
// Keep TCP connection open briefly to allow client to process close
1995+
tokio::time::sleep(WEBSOCKET_CLOSE_LINGER).await;
19871996

1988-
break;
1989-
}
1990-
Err(err) => {
1991-
ws_handle
1992-
.send(to_hyper_close(Some(err_to_close_frame(
1993-
err, ray_id,
1994-
))))
1995-
.await?;
1997+
break;
1998+
}
1999+
Err(err) => {
2000+
ws_handle
2001+
.send(to_hyper_close(Some(err_to_close_frame(
2002+
err, ray_id,
2003+
))))
2004+
.await?;
19962005

1997-
// Flush to ensure close frame is sent
1998-
ws_handle.flush().await?;
2006+
// Flush to ensure close frame is sent
2007+
ws_handle.flush().await?;
19992008

2000-
// Keep TCP connection open briefly to allow client to process close
2001-
tokio::time::sleep(WEBSOCKET_CLOSE_LINGER).await;
2009+
// Keep TCP connection open briefly to allow client to process close
2010+
tokio::time::sleep(WEBSOCKET_CLOSE_LINGER).await;
20022011

2003-
break;
2004-
}
2012+
break;
20052013
}
20062014
}
20072015
}
@@ -2509,9 +2517,9 @@ fn is_retryable_ws_error(err: &anyhow::Error) -> bool {
25092517
}
25102518
}
25112519

2512-
fn is_ws_retry(err: &anyhow::Error) -> bool {
2520+
fn is_ws_hibernate(err: &anyhow::Error) -> bool {
25132521
if let Some(rivet_err) = err.chain().find_map(|x| x.downcast_ref::<RivetError>()) {
2514-
rivet_err.group() == "guard" && rivet_err.code() == "websocket_service_retry"
2522+
rivet_err.group() == "guard" && rivet_err.code() == "websocket_service_hibernate"
25152523
} else {
25162524
false
25172525
}

engine/packages/guard-core/src/websocket_handle.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use anyhow::*;
2-
use futures_util::{SinkExt, StreamExt};
2+
use futures_util::{SinkExt, StreamExt, stream::Peekable};
33
use hyper::upgrade::Upgraded;
44
use hyper_tungstenite::HyperWebsocket;
55
use hyper_tungstenite::tungstenite::Message as WsMessage;
@@ -8,7 +8,8 @@ use std::sync::Arc;
88
use tokio::sync::Mutex;
99
use tokio_tungstenite::WebSocketStream;
1010

11-
pub type WebSocketReceiver = futures_util::stream::SplitStream<WebSocketStream<TokioIo<Upgraded>>>;
11+
pub type WebSocketReceiver =
12+
Peekable<futures_util::stream::SplitStream<WebSocketStream<TokioIo<Upgraded>>>>;
1213

1314
pub type WebSocketSender =
1415
futures_util::stream::SplitSink<WebSocketStream<TokioIo<Upgraded>>, WsMessage>;
@@ -26,7 +27,7 @@ impl WebSocketHandle {
2627

2728
Ok(Self {
2829
ws_tx: Arc::new(Mutex::new(ws_tx)),
29-
ws_rx: Arc::new(Mutex::new(ws_rx)),
30+
ws_rx: Arc::new(Mutex::new(ws_rx.peekable())),
3031
})
3132
}
3233

engine/packages/guard-core/tests/custom_serve.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,10 @@ impl CustomServeTrait for TestCustomServe {
102102

103103
Ok(None)
104104
}
105+
106+
async fn handle_websocket_hibernation(&self, _websocket: WebSocketHandle) -> Result<()> {
107+
todo!();
108+
}
105109
}
106110

107111
// Create routing function that returns CustomServe

engine/packages/guard/src/routing/api_public.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::sync::Arc;
22

3-
use anyhow::*;
3+
use anyhow::{Context, Result};
44
use async_trait::async_trait;
55
use bytes::Bytes;
66
use gas::prelude::*;
@@ -30,14 +30,14 @@ impl CustomServeTrait for ApiPublicService {
3030
let response = service
3131
.call(req)
3232
.await
33-
.map_err(|e| anyhow::anyhow!("Failed to call api-public service: {}", e))?;
33+
.context("failed to call api-public service")?;
3434

3535
// Collect the body and convert to ResponseBody
3636
let (parts, body) = response.into_parts();
3737
let collected = body
3838
.collect()
3939
.await
40-
.map_err(|e| anyhow::anyhow!("Failed to collect response body: {}", e))?;
40+
.context("failed to collect response body")?;
4141
let bytes = collected.to_bytes();
4242
let response_body = ResponseBody::Full(Full::new(bytes));
4343
let response = Response::from_parts(parts, response_body);
@@ -55,6 +55,10 @@ impl CustomServeTrait for ApiPublicService {
5555
) -> Result<Option<CloseFrame>> {
5656
bail!("api-public does not support WebSocket connections")
5757
}
58+
59+
async fn handle_websocket_hibernation(&self, _client_ws: WebSocketHandle) -> Result<()> {
60+
bail!("api-public does not support WebSocket hibernation")
61+
}
5862
}
5963

6064
/// Route requests to the api-public service

engine/packages/guard/src/routing/pegboard_gateway.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ async fn route_request_inner(
235235

236236
// Return pegboard-gateway instance with path
237237
let gateway = pegboard_gateway::PegboardGateway::new(
238+
ctx.clone(),
238239
shared_state.pegboard_gateway.clone(),
239240
runner_id,
240241
actor_id,

0 commit comments

Comments
 (0)