Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ Key points:
- Add `Serialize, Deserialize` derives for errors with metadata fields
- Always return anyhow errors from failable functions
- For example: `fn foo() -> Result<i64> { /* ... */ }`
- Import anyhow using `use anyhow::*` instead of importing individual types
- Do not glob import (`::*`) from anyhow. Instead, import individual types and traits

**Dependency Management**
- When adding a dependency, check for a workspace dependency in Cargo.toml
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 0 additions & 5 deletions engine/artifacts/errors/guard.websocket_service_retry.json

This file was deleted.

29 changes: 22 additions & 7 deletions engine/packages/guard-core/src/custom_serve.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use anyhow::*;
use anyhow::{Result, bail};
use async_trait::async_trait;
use bytes::Bytes;
use http_body_util::Full;
Expand All @@ -10,6 +10,11 @@ use crate::WebSocketHandle;
use crate::proxy_service::ResponseBody;
use crate::request_context::RequestContext;

/// Outcome reported by `CustomServeTrait::handle_websocket_hibernation` once
/// hibernation of a websocket ends.
///
/// In both cases the proxy re-resolves the route and reconnects to the
/// downstream target afterwards.
pub enum HibernationResult {
/// Hibernation ended and the connection should carry on normally.
// NOTE(review): presumably produced when the client sends a regular message
// during hibernation; confirm against the implementors.
Continue,
/// A close frame was received from the client during hibernation. The proxy
/// still reconnects so that the downstream learns the connection has closed.
Close,
}

/// Trait for custom request serving logic that can handle both HTTP and WebSocket requests
#[async_trait]
pub trait CustomServeTrait: Send + Sync {
Expand All @@ -23,11 +28,21 @@ pub trait CustomServeTrait: Send + Sync {
/// Handle a WebSocket connection after upgrade. Supports connection retries.
async fn handle_websocket(
&self,
websocket: WebSocketHandle,
headers: &hyper::HeaderMap,
path: &str,
request_context: &mut RequestContext,
_websocket: WebSocketHandle,
_headers: &hyper::HeaderMap,
_path: &str,
_request_context: &mut RequestContext,
// Identifies the websocket across retries.
unique_request_id: Uuid,
) -> Result<Option<CloseFrame>>;
_unique_request_id: Uuid,
) -> Result<Option<CloseFrame>> {
bail!("service does not support websockets");
}

/// Handle a WebSocket while its downstream service is hibernating.
///
/// Invoked by the proxy when `handle_websocket` fails with the
/// `guard.websocket_service_hibernate` error.
///
/// Returns `HibernationResult::Close` if the websocket should close, and
/// `HibernationResult::Continue` otherwise.
///
/// The default implementation always returns an error, for services that do
/// not support websocket hibernation.
async fn handle_websocket_hibernation(
&self,
_websocket: WebSocketHandle,
) -> Result<HibernationResult> {
bail!("service does not support websocket hibernation");
}
}
8 changes: 6 additions & 2 deletions engine/packages/guard-core/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,12 @@ pub struct ServiceUnavailable;
pub struct WebSocketServiceUnavailable;

#[derive(RivetError, Serialize, Deserialize)]
#[error("guard", "websocket_service_retry", "WebSocket service retry.")]
pub struct WebSocketServiceRetry;
#[error(
"guard",
"websocket_service_hibernate",
"Initiate WebSocket service hibernation."
)]
pub struct WebSocketServiceHibernate;

#[derive(RivetError, Serialize, Deserialize)]
#[error("guard", "websocket_service_timeout", "WebSocket service timed out.")]
Expand Down
180 changes: 104 additions & 76 deletions engine/packages/guard-core/src/proxy_service.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use anyhow::{Context, Result, bail};
use anyhow::{Context, Result, bail, ensure};
use bytes::Bytes;
use futures_util::{SinkExt, StreamExt};
use http_body_util::{BodyExt, Full};
Expand Down Expand Up @@ -31,7 +31,9 @@ use url::Url;
use uuid::Uuid;

use crate::{
WebSocketHandle, custom_serve::CustomServeTrait, errors, metrics,
WebSocketHandle,
custom_serve::{CustomServeTrait, HibernationResult},
errors, metrics,
request_context::RequestContext,
};

Expand Down Expand Up @@ -1828,7 +1830,7 @@ impl ProxyService {
);
}
ResolveRouteOutput::Response(_) => unreachable!(),
ResolveRouteOutput::CustomServe(mut handlers) => {
ResolveRouteOutput::CustomServe(mut handler) => {
tracing::debug!(%req_path, "Spawning task to handle WebSocket communication");
let mut request_context = request_context.clone();
let req_headers = req_headers.clone();
Expand All @@ -1838,14 +1840,15 @@ impl ProxyService {
tokio::spawn(
async move {
let request_id = Uuid::new_v4();
let mut ws_hibernation_close = false;
let mut attempts = 0u32;

let ws_handle = WebSocketHandle::new(client_ws)
.await
.context("failed initiating websocket handle")?;

loop {
match handlers
match handler
.handle_websocket(
ws_handle.clone(),
&req_headers,
Expand Down Expand Up @@ -1895,18 +1898,43 @@ impl ProxyService {
Err(err) => {
tracing::debug!(?err, "websocket handler error");

// Denotes that the connection did not fail, but needs to be retried to
// resolve a new target
let ws_retry = is_ws_retry(&err);
// Denotes that the connection did not fail, but the downstream has closed
let ws_hibernate = is_ws_hibernate(&err);

if ws_retry {
if ws_hibernate {
attempts = 0;
} else {
attempts += 1;
}

if attempts > max_attempts
|| (!is_retryable_ws_error(&err) && !ws_retry)
if ws_hibernate {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do we handle failure cases if we fail to re-wake the actor when the client disconnects?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't do anything if the client disconnects. Once we add waking, if the actor fails to wake, the only thing we can do is log the error.

// This should be unreachable because as soon as the actor is
// reconnected to after hibernation the gateway will consume the close
// frame from the client ws stream
ensure!(
!ws_hibernation_close,
"should not be hibernating again after receiving a close frame during hibernation"
);

// After this function returns:
// - the route will be resolved again
// - the websocket will connect to the new downstream target
// - the gateway will continue reading messages from the client ws
// (starting with the message that caused the hibernation to end)
let res = handler
.handle_websocket_hibernation(ws_handle.clone())
.await?;

// Despite receiving a close frame from the client during hibernation
// we are going to reconnect to the actor so that it knows the
// connection has closed
if let HibernationResult::Close = res {
tracing::debug!("starting hibernating websocket close");

ws_hibernation_close = true;
}
} else if attempts > max_attempts
|| !is_retryable_ws_error(&err)
{
tracing::debug!(
?attempts,
Expand All @@ -1929,79 +1957,79 @@ impl ProxyService {

break;
} else {
if !ws_retry {
let backoff = ProxyService::calculate_backoff(
attempts,
initial_interval,
);
let backoff = ProxyService::calculate_backoff(
attempts,
initial_interval,
);

tracing::debug!(
?backoff,
"WebSocket attempt {attempts} failed (service unavailable)"
);
tracing::debug!(
?backoff,
"WebSocket attempt {attempts} failed (service unavailable)"
);

tokio::time::sleep(backoff).await;
}
// Apply backoff for retryable error
tokio::time::sleep(backoff).await;
}

match state
.resolve_route(
&req_host,
&req_path,
&req_method,
state.port_type.clone(),
&req_headers,
true,
)
.await
{
Ok(ResolveRouteOutput::CustomServe(new_handlers)) => {
handlers = new_handlers;
continue;
}
Ok(ResolveRouteOutput::Response(response)) => {
ws_handle
.send(to_hyper_close(Some(str_to_close_frame(
response.message.as_ref(),
))))
.await?;

// Flush to ensure close frame is sent
ws_handle.flush().await?;

// Keep TCP connection open briefly to allow client to process close
tokio::time::sleep(WEBSOCKET_CLOSE_LINGER).await;
}
Ok(ResolveRouteOutput::Target(_)) => {
ws_handle
.send(to_hyper_close(Some(err_to_close_frame(
errors::WebSocketTargetChanged.build(),
ray_id,
))))
.await?;
// Retry route resolution
match state
.resolve_route(
&req_host,
&req_path,
&req_method,
state.port_type.clone(),
&req_headers,
true,
)
.await
{
Ok(ResolveRouteOutput::CustomServe(new_handler)) => {
handler = new_handler;
continue;
}
Ok(ResolveRouteOutput::Response(response)) => {
ws_handle
.send(to_hyper_close(Some(str_to_close_frame(
response.message.as_ref(),
))))
.await?;

// Flush to ensure close frame is sent
ws_handle.flush().await?;

// Keep TCP connection open briefly to allow client to process close
tokio::time::sleep(WEBSOCKET_CLOSE_LINGER).await;
}
Ok(ResolveRouteOutput::Target(_)) => {
ws_handle
.send(to_hyper_close(Some(err_to_close_frame(
errors::WebSocketTargetChanged.build(),
ray_id,
))))
.await?;

// Flush to ensure close frame is sent
ws_handle.flush().await?;
// Flush to ensure close frame is sent
ws_handle.flush().await?;

// Keep TCP connection open briefly to allow client to process close
tokio::time::sleep(WEBSOCKET_CLOSE_LINGER).await;
// Keep TCP connection open briefly to allow client to process close
tokio::time::sleep(WEBSOCKET_CLOSE_LINGER).await;

break;
}
Err(err) => {
ws_handle
.send(to_hyper_close(Some(err_to_close_frame(
err, ray_id,
))))
.await?;
break;
}
Err(err) => {
ws_handle
.send(to_hyper_close(Some(err_to_close_frame(
err, ray_id,
))))
.await?;

// Flush to ensure close frame is sent
ws_handle.flush().await?;
// Flush to ensure close frame is sent
ws_handle.flush().await?;

// Keep TCP connection open briefly to allow client to process close
tokio::time::sleep(WEBSOCKET_CLOSE_LINGER).await;
// Keep TCP connection open briefly to allow client to process close
tokio::time::sleep(WEBSOCKET_CLOSE_LINGER).await;

break;
}
break;
}
}
}
Expand Down Expand Up @@ -2509,9 +2537,9 @@ fn is_retryable_ws_error(err: &anyhow::Error) -> bool {
}
}

fn is_ws_retry(err: &anyhow::Error) -> bool {
fn is_ws_hibernate(err: &anyhow::Error) -> bool {
if let Some(rivet_err) = err.chain().find_map(|x| x.downcast_ref::<RivetError>()) {
rivet_err.group() == "guard" && rivet_err.code() == "websocket_service_retry"
rivet_err.group() == "guard" && rivet_err.code() == "websocket_service_hibernate"
} else {
false
}
Expand Down
7 changes: 4 additions & 3 deletions engine/packages/guard-core/src/websocket_handle.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use anyhow::*;
use futures_util::{SinkExt, StreamExt};
use futures_util::{SinkExt, StreamExt, stream::Peekable};
use hyper::upgrade::Upgraded;
use hyper_tungstenite::HyperWebsocket;
use hyper_tungstenite::tungstenite::Message as WsMessage;
Expand All @@ -8,7 +8,8 @@ use std::sync::Arc;
use tokio::sync::Mutex;
use tokio_tungstenite::WebSocketStream;

pub type WebSocketReceiver = futures_util::stream::SplitStream<WebSocketStream<TokioIo<Upgraded>>>;
pub type WebSocketReceiver =
Peekable<futures_util::stream::SplitStream<WebSocketStream<TokioIo<Upgraded>>>>;

pub type WebSocketSender =
futures_util::stream::SplitSink<WebSocketStream<TokioIo<Upgraded>>, WsMessage>;
Expand All @@ -26,7 +27,7 @@ impl WebSocketHandle {

Ok(Self {
ws_tx: Arc::new(Mutex::new(ws_tx)),
ws_rx: Arc::new(Mutex::new(ws_rx)),
ws_rx: Arc::new(Mutex::new(ws_rx.peekable())),
})
}

Expand Down
Loading
Loading