diff --git a/Cargo.lock b/Cargo.lock index 4431806248..f8892ace8d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4240,12 +4240,14 @@ version = "2.0.24-rc.1" dependencies = [ "anyhow", "axum 0.8.4", + "base64 0.22.1", "epoxy", "futures-util", "gasoline", "indexmap 2.10.0", "namespace", "pegboard", + "pegboard-actor-kv", "rivet-api-builder", "rivet-api-types", "rivet-api-util", @@ -4292,6 +4294,7 @@ dependencies = [ "tokio", "tower-http", "tracing", + "urlencoding", "utoipa", "vergen", "vergen-gitcl", diff --git a/engine/artifacts/errors/actor.kv_key_not_found.json b/engine/artifacts/errors/actor.kv_key_not_found.json new file mode 100644 index 0000000000..31e465083b --- /dev/null +++ b/engine/artifacts/errors/actor.kv_key_not_found.json @@ -0,0 +1,5 @@ +{ + "code": "kv_key_not_found", + "group": "actor", + "message": "The KV key does not exist for this actor." +} \ No newline at end of file diff --git a/engine/artifacts/openapi.json b/engine/artifacts/openapi.json index e13da9b43f..ff5cbd4f6b 100644 --- a/engine/artifacts/openapi.json +++ b/engine/artifacts/openapi.json @@ -274,6 +274,49 @@ ] } }, + "/actors/{actor_id}/kv/keys/{key}": { + "get": { + "tags": [ + "actors::kv_get" + ], + "operationId": "actors_kv_get", + "parameters": [ + { + "name": "actor_id", + "in": "path", + "required": true, + "schema": { + "$ref": "#/components/schemas/RivetId" + } + }, + { + "name": "key", + "in": "path", + "required": true, + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ActorsKvGetResponse" + } + } + } + } + }, + "security": [ + { + "bearer_auth": [] + } + ] + } + }, "/datacenters": { "get": { "tags": [ @@ -1020,6 +1063,22 @@ } } }, + "ActorsKvGetResponse": { + "type": "object", + "required": [ + "value", + "update_ts" + ], + "properties": { + "update_ts": { + "type": "integer", + "format": "int64" + }, + "value": { + "type": "string" + } + } + }, "ActorsListNamesResponse": { "type": "object", "required": [ diff --git a/engine/packages/api-peer/Cargo.toml b/engine/packages/api-peer/Cargo.toml index 8b95d7e98a..060e9d6a3a 100644 --- a/engine/packages/api-peer/Cargo.toml +++ b/engine/packages/api-peer/Cargo.toml @@ -8,6 +8,7 @@ license.workspace = true [dependencies] anyhow.workspace = true axum.workspace = true +base64.workspace = true gas.workspace = true epoxy.workspace = true futures-util.workspace = true @@ -27,6 +28,7 @@ tokio.workspace = true tracing.workspace = true namespace.workspace = true pegboard.workspace = true +pegboard-actor-kv.workspace = true universalpubsub.workspace = true uuid.workspace = true utoipa.workspace = true diff --git a/engine/packages/api-peer/src/actors/kv_get.rs b/engine/packages/api-peer/src/actors/kv_get.rs new file mode 100644 index 0000000000..3ce45cf1d9 --- /dev/null +++ b/engine/packages/api-peer/src/actors/kv_get.rs @@ -0,0 +1,67 @@ +use anyhow::*; +use base64::prelude::BASE64_STANDARD; +use base64::Engine; +use pegboard_actor_kv as actor_kv; +use rivet_api_builder::ApiCtx; +use rivet_util::Id; +use serde::{Deserialize, Serialize}; +use utoipa::ToSchema; + +#[derive(Deserialize)] +#[serde(deny_unknown_fields)] +pub struct KvGetPath { + pub actor_id: Id, + pub key: String, +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(deny_unknown_fields)] +pub struct KvGetQuery {} + +#[derive(Serialize, ToSchema)] +#[schema(as = ActorsKvGetResponse)] +pub struct KvGetResponse { + /// Value encoded in base 64. + pub value: String, + pub update_ts: i64, +} + +#[utoipa::path( + get, + operation_id = "actors_kv_get", + path = "/actors/{actor_id}/kv/keys/{key}", + params( + ("actor_id" = Id, Path), + ("key" = String, Path), + ), + responses( + (status = 200, body = KvGetResponse), + ), +)] +#[tracing::instrument(skip_all)] +pub async fn kv_get(ctx: ApiCtx, path: KvGetPath, _query: KvGetQuery) -> Result { + // Decode base64 key + let key_bytes = BASE64_STANDARD + .decode(&path.key) + .context("failed to decode base64 key")?; + + // Get the KV value + let udb = ctx.pools().udb()?; + let (keys, values, metadata) = + actor_kv::get(&*udb, path.actor_id, vec![key_bytes.clone()]).await?; + + // Check if key was found + if keys.is_empty() { + return Err(pegboard::errors::Actor::KvKeyNotFound.build()); + } + + // Encode value as base64 + let value_base64 = BASE64_STANDARD.encode(&values[0]); + + Ok(KvGetResponse { + value: value_base64, + // NOTE: Intentionally uses different name in public API. `create_ts` is actually + // `update_ts`. + update_ts: metadata[0].create_ts, + }) +} diff --git a/engine/packages/api-peer/src/actors/mod.rs b/engine/packages/api-peer/src/actors/mod.rs index ce36036d8f..9451cd59eb 100644 --- a/engine/packages/api-peer/src/actors/mod.rs +++ b/engine/packages/api-peer/src/actors/mod.rs @@ -1,4 +1,5 @@ pub mod create; pub mod delete; +pub mod kv_get; pub mod list; pub mod list_names; diff --git a/engine/packages/api-peer/src/router.rs b/engine/packages/api-peer/src/router.rs index a4f567bb1d..05fdc31fa8 100644 --- a/engine/packages/api-peer/src/router.rs +++ b/engine/packages/api-peer/src/router.rs @@ -26,6 +26,10 @@ pub async fn router( .route("/actors", post(actors::create::create)) .route("/actors/{actor_id}", delete(actors::delete::delete)) .route("/actors/names", get(actors::list_names::list_names)) + .route( + "/actors/{actor_id}/kv/keys/{key}", + get(actors::kv_get::kv_get), + ) // MARK: Runners .route("/runners", get(runners::list)) .route("/runners/names", get(runners::list_names)) diff --git a/engine/packages/api-public/Cargo.toml b/engine/packages/api-public/Cargo.toml index 88cdc25d21..77e6aa2f61 100644 --- a/engine/packages/api-public/Cargo.toml +++ b/engine/packages/api-public/Cargo.toml @@ -30,6 +30,7 @@ serde.workspace = true tokio.workspace = true tower-http.workspace = true tracing.workspace = true +urlencoding.workspace = true utoipa.workspace = true [build-dependencies] diff --git a/engine/packages/api-public/src/actors/kv_get.rs b/engine/packages/api-public/src/actors/kv_get.rs new file mode 100644 index 0000000000..5e9dd3b9db --- /dev/null +++ b/engine/packages/api-public/src/actors/kv_get.rs @@ -0,0 +1,75 @@ +use anyhow::Result; +use axum::response::{IntoResponse, Response}; +use rivet_api_builder::{ + ApiError, + extract::{Extension, Path}, +}; +use rivet_api_util::request_remote_datacenter_raw; +use rivet_util::Id; +use serde::{Deserialize, Serialize}; +use utoipa::ToSchema; + +use crate::ctx::ApiCtx; + +#[derive(Deserialize)] +#[serde(deny_unknown_fields)] +pub struct KvGetPath { + pub actor_id: Id, + pub key: String, +} + +#[derive(Serialize, ToSchema)] +#[schema(as = ActorsKvGetResponse)] +pub struct KvGetResponse { + pub value: String, + pub update_ts: i64, +} + +#[utoipa::path( + get, + operation_id = "actors_kv_get", + path = "/actors/{actor_id}/kv/keys/{key}", + params( + ("actor_id" = Id, Path), + ("key" = String, Path), + ), + responses( + (status = 200, body = KvGetResponse), + ), + security(("bearer_auth" = [])), +)] +#[tracing::instrument(skip_all)] +pub async fn kv_get(Extension(ctx): Extension, Path(path): Path) -> Response { + match kv_get_inner(ctx, path).await { + Ok(response) => response, + Err(err) => ApiError::from(err).into_response(), + } +} + +#[tracing::instrument(skip_all)] +async fn kv_get_inner(ctx: ApiCtx, path: KvGetPath) -> Result { + use axum::Json; + + ctx.auth().await?; + + if path.actor_id.label() == ctx.config().dc_label() { + let peer_path = rivet_api_peer::actors::kv_get::KvGetPath { + actor_id: path.actor_id, + key: path.key, + }; + let peer_query = rivet_api_peer::actors::kv_get::KvGetQuery {}; + let res = rivet_api_peer::actors::kv_get::kv_get(ctx.into(), peer_path, peer_query).await?; + + Ok(Json(res).into_response()) + } else { + request_remote_datacenter_raw( + &ctx, + path.actor_id.label(), + &format!("/actors/{}/kv/keys/{}", path.actor_id, urlencoding::encode(&path.key)), + axum::http::Method::GET, + Option::<&()>::None, + Option::<&()>::None, + ) + .await + } +} diff --git a/engine/packages/api-public/src/actors/mod.rs b/engine/packages/api-public/src/actors/mod.rs index 9a563baae5..d1adaf1d36 100644 --- a/engine/packages/api-public/src/actors/mod.rs +++ b/engine/packages/api-public/src/actors/mod.rs @@ -1,6 +1,7 @@ pub mod create; pub mod delete; pub mod get_or_create; +pub mod kv_get; pub mod list; pub mod list_names; pub mod utils; diff --git a/engine/packages/api-public/src/router.rs b/engine/packages/api-public/src/router.rs index 06ef91061c..2b71b35c9b 100644 --- a/engine/packages/api-public/src/router.rs +++ b/engine/packages/api-public/src/router.rs @@ -18,6 +18,7 @@ use crate::{actors, ctx, datacenters, health, metadata, namespaces, runner_confi actors::delete::delete, actors::list_names::list_names, actors::get_or_create::get_or_create, + actors::kv_get::kv_get, runners::list, runners::list_names, namespaces::list, @@ -88,6 +89,10 @@ pub async fn router( "/actors/names", axum::routing::get(actors::list_names::list_names), ) + .route( + "/actors/{actor_id}/kv/keys/{key}", + axum::routing::get(actors::kv_get::kv_get), + ) // MARK: Runners .route("/runners", axum::routing::get(runners::list)) .route("/runners/names", axum::routing::get(runners::list_names)) diff --git a/engine/packages/pegboard/src/errors.rs b/engine/packages/pegboard/src/errors.rs index dbd3173e63..62ab7138a3 100644 --- a/engine/packages/pegboard/src/errors.rs +++ b/engine/packages/pegboard/src/errors.rs @@ -63,6 +63,9 @@ pub enum Actor { namespace: String, runner_name: String, }, + + #[error("kv_key_not_found", "The KV key does not exist for this actor.")] + KvKeyNotFound, } #[derive(RivetError, Debug, Clone, Deserialize, Serialize)] diff --git a/examples/cursors-raw-websocket/src/backend/registry.ts b/examples/cursors-raw-websocket/src/backend/registry.ts index 01ee63eabe..a7c6a1954a 100644 --- a/examples/cursors-raw-websocket/src/backend/registry.ts +++ b/examples/cursors-raw-websocket/src/backend/registry.ts @@ -56,7 +56,7 @@ export const cursorRoom = actor({ }, // Handle WebSocket connections - onWebSocket: async (c, websocket: UniversalWebSocket, { request }) => { + onWebSocket: async (c, websocket: UniversalWebSocket) => { const url = new URL(request.url); const sessionId = url.searchParams.get("sessionId"); diff --git a/rivetkit-typescript/packages/cloudflare-workers/src/actor-driver.ts b/rivetkit-typescript/packages/cloudflare-workers/src/actor-driver.ts index dbb09e359b..1bace9e7af 100644 --- a/rivetkit-typescript/packages/cloudflare-workers/src/actor-driver.ts +++ b/rivetkit-typescript/packages/cloudflare-workers/src/actor-driver.ts @@ -6,10 +6,10 @@ import type { } from "rivetkit"; import { lookupInRegistry } from "rivetkit"; import type { Client } from "rivetkit/client"; -import { - type ActorDriver, - type AnyActorInstance, - type ManagerDriver, +import type { + ActorDriver, + AnyActorInstance, + ManagerDriver, } from "rivetkit/driver-helpers"; import { promiseWithResolvers } from "rivetkit/utils"; import { KEYS } from "./actor-handler-do"; @@ -239,7 +239,6 @@ export class CloudflareActorsActorDriver implements ActorDriver { // Persist data key return Uint8Array.from([1]); } - } export function createCloudflareActorsActorDriverBuilder( diff --git a/rivetkit-typescript/packages/cloudflare-workers/src/manager-driver.ts b/rivetkit-typescript/packages/cloudflare-workers/src/manager-driver.ts index 43013b8825..cfdf78ebbe 100644 --- a/rivetkit-typescript/packages/cloudflare-workers/src/manager-driver.ts +++ b/rivetkit-typescript/packages/cloudflare-workers/src/manager-driver.ts @@ -7,6 +7,7 @@ import { type GetOrCreateWithKeyInput, type GetWithKeyInput, generateRandomString, + type ListActorsInput, type ManagerDisplayInformation, type ManagerDriver, WS_PROTOCOL_ACTOR, @@ -348,6 +349,14 @@ export class CloudflareActorsManagerDriver implements ManagerDriver { }; } + async listActors({ c, name }: ListActorsInput): Promise { + logger().warn({ + msg: "listActors not fully implemented for Cloudflare Workers", + name, + }); + return []; + } + // Helper method to build actor output from an ID async #buildActorOutput( c: any, diff --git a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/conn-params.ts b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/conn-params.ts index 422b508877..4116a4432c 100644 --- a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/conn-params.ts +++ b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/conn-params.ts @@ -2,7 +2,7 @@ import { actor } from "rivetkit"; export const counterWithParams = actor({ state: { count: 0, initializers: [] as string[] }, - createConnState: (c, opts, params: { name?: string }) => { + createConnState: (c, params: { name?: string }) => { return { name: params.name || "anonymous", }; diff --git a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/conn-state.ts b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/conn-state.ts index 2dbb1c172d..92a93963b5 100644 --- a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/conn-state.ts +++ b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/conn-state.ts @@ -16,7 +16,6 @@ export const connStateActor = actor({ // Define connection state createConnState: ( c, - opts, params: { username?: string; role?: string; noCount?: boolean }, ): ConnState => { return { diff --git a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/lifecycle.ts b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/lifecycle.ts index 292ca8da91..2fb790e734 100644 --- a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/lifecycle.ts +++ b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/lifecycle.ts @@ -7,13 +7,13 @@ export const counterWithLifecycle = actor({ count: 0, events: [] as string[], }, - createConnState: (c, opts, params: ConnParams) => ({ + createConnState: (c, params: ConnParams) => ({ joinTime: Date.now(), }), onWake: (c) => { c.state.events.push("onWake"); }, - onBeforeConnect: (c, opts, params: ConnParams) => { + onBeforeConnect: (c, params: ConnParams) => { if (params?.trackLifecycle) c.state.events.push("onBeforeConnect"); }, onConnect: (c, conn) => { diff --git a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/raw-http-request-properties.ts b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/raw-http-request-properties.ts index 5a6c8f35c8..d441841d8a 100644 --- a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/raw-http-request-properties.ts +++ b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/raw-http-request-properties.ts @@ -1,8 +1,11 @@ -import { type ActorContext, actor } from "rivetkit"; +import { actor, type RequestContext } from "rivetkit"; export const rawHttpRequestPropertiesActor = actor({ actions: {}, - onFetch(ctx: ActorContext, request: Request) { + onRequest( + ctx: RequestContext, + request: Request, + ) { // Extract all relevant Request properties const url = new URL(request.url); const method = request.method; diff --git a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/raw-http.ts b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/raw-http.ts index 0d7c31a97a..66b7506900 100644 --- a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/raw-http.ts +++ b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/raw-http.ts @@ -1,11 +1,14 @@ import { Hono } from "hono"; -import { type ActorContext, actor } from "rivetkit"; +import { actor, type RequestContext } from "rivetkit"; export const rawHttpActor = actor({ state: { requestCount: 0, }, - onFetch(ctx: ActorContext, request: Request) { + onRequest( + ctx: RequestContext, + request: Request, + ) { const url = new URL(request.url); const method = request.method; @@ -57,7 +60,7 @@ export const rawHttpNoHandlerActor = actor({ }); export const rawHttpVoidReturnActor = actor({ - onFetch(ctx, request) { + onRequest(ctx, request) { // Intentionally return void to test error handling return undefined as any; }, @@ -107,7 +110,10 @@ export const rawHttpHonoActor = actor({ // Return the router as a var return { router }; }, - onFetch(ctx: ActorContext, request: Request) { + onRequest( + ctx: RequestContext, + request: Request, + ) { // Use the Hono router from vars return ctx.vars.router.fetch(request); }, diff --git a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/raw-websocket.ts b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/raw-websocket.ts index 58bd5ce752..0c8e181987 100644 --- a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/raw-websocket.ts +++ b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/raw-websocket.ts @@ -5,7 +5,7 @@ export const rawWebSocketActor = actor({ connectionCount: 0, messageCount: 0, }, - onWebSocket(ctx, websocket, opts) { + onWebSocket(ctx, websocket) { ctx.state.connectionCount = ctx.state.connectionCount + 1; console.log( `[ACTOR] New connection, count: ${ctx.state.connectionCount}`, @@ -51,13 +51,15 @@ export const rawWebSocketActor = actor({ }), ); } else if (parsed.type === "getRequestInfo") { - // Send back the request URL info + // Send back the request URL info if available + const url = ctx.request?.url || "ws://actor/websocket"; + const urlObj = new URL(url); websocket.send( JSON.stringify({ type: "requestInfo", - url: opts.request.url, - pathname: new URL(opts.request.url).pathname, - search: new URL(opts.request.url).search, + url: url, + pathname: urlObj.pathname, + search: urlObj.search, }), ); } else { @@ -93,7 +95,7 @@ export const rawWebSocketActor = actor({ }); export const rawWebSocketBinaryActor = actor({ - onWebSocket(ctx, websocket, opts) { + onWebSocket(ctx, websocket) { // Handle binary data websocket.addEventListener("message", (event: any) => { const data = event.data; diff --git a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/request-access.ts b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/request-access.ts index 481526ef6f..2825b4faf8 100644 --- a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/request-access.ts +++ b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/request-access.ts @@ -18,7 +18,7 @@ export const requestAccessActor = actor({ requestMethod: null as string | null, requestHeaders: {} as Record, }, - onFetchRequest: { + onRequestRequest: { hasRequest: false, requestUrl: null as string | null, requestMethod: null as string | null, @@ -31,19 +31,19 @@ export const requestAccessActor = actor({ requestHeaders: {} as Record, }, }, - createConnState: (c, { request }, params: { trackRequest?: boolean }) => { + createConnState: (c, params: { trackRequest?: boolean }) => { // In createConnState, the state isn't available yet. return { trackRequest: params?.trackRequest || false, requestInfo: - params?.trackRequest && request + params?.trackRequest && c.request ? { hasRequest: true, - requestUrl: request.url, - requestMethod: request.method, + requestUrl: c.request.url, + requestMethod: c.request.method, requestHeaders: Object.fromEntries( - request.headers.entries(), + c.request.headers.entries(), ), } : null, @@ -55,16 +55,16 @@ export const requestAccessActor = actor({ c.state.createConnStateRequest = conn.state.requestInfo; } }, - onBeforeConnect: (c, { request }, params) => { + onBeforeConnect: (c, params) => { if (params?.trackRequest) { - if (request) { + if (c.request) { c.state.onBeforeConnectRequest.hasRequest = true; - c.state.onBeforeConnectRequest.requestUrl = request.url; - c.state.onBeforeConnectRequest.requestMethod = request.method; + c.state.onBeforeConnectRequest.requestUrl = c.request.url; + c.state.onBeforeConnectRequest.requestMethod = c.request.method; // Store select headers const headers: Record = {}; - request.headers.forEach((value, key) => { + c.request.headers.forEach((value, key) => { headers[key] = value; }); c.state.onBeforeConnectRequest.requestHeaders = headers; @@ -74,18 +74,18 @@ export const requestAccessActor = actor({ } } }, - onFetch: (c, request) => { + onRequest: (c, request) => { // Store request info - c.state.onFetchRequest.hasRequest = true; - c.state.onFetchRequest.requestUrl = request.url; - c.state.onFetchRequest.requestMethod = request.method; + c.state.onRequestRequest.hasRequest = true; + c.state.onRequestRequest.requestUrl = request.url; + c.state.onRequestRequest.requestMethod = request.method; // Store select headers const headers: Record = {}; request.headers.forEach((value, key) => { headers[key] = value; }); - c.state.onFetchRequest.requestHeaders = headers; + c.state.onRequestRequest.requestHeaders = headers; // Return response with request info return new Response( @@ -101,15 +101,16 @@ export const requestAccessActor = actor({ }, ); }, - onWebSocket: (c, websocket, { request }) => { + onWebSocket: (c, websocket) => { + if (!c.request) throw "Missing request"; // Store request info c.state.onWebSocketRequest.hasRequest = true; - c.state.onWebSocketRequest.requestUrl = request.url; - c.state.onWebSocketRequest.requestMethod = request.method; + c.state.onWebSocketRequest.requestUrl = c.request.url; + c.state.onWebSocketRequest.requestMethod = c.request.method; // Store select headers const headers: Record = {}; - request.headers.forEach((value, key) => { + c.request.headers.forEach((value, key) => { headers[key] = value; }); c.state.onWebSocketRequest.requestHeaders = headers; @@ -118,8 +119,8 @@ export const requestAccessActor = actor({ websocket.send( JSON.stringify({ hasRequest: true, - requestUrl: request.url, - requestMethod: request.method, + requestUrl: c.request.url, + requestMethod: c.request.method, requestHeaders: headers, }), ); @@ -134,7 +135,7 @@ export const requestAccessActor = actor({ return { onBeforeConnect: c.state.onBeforeConnectRequest, createConnState: c.state.createConnStateRequest, - onFetch: c.state.onFetchRequest, + onRequest: c.state.onRequestRequest, onWebSocket: c.state.onWebSocketRequest, }; }, diff --git a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/sleep.ts b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/sleep.ts index 8c77c5c1ba..49d963786b 100644 --- a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/sleep.ts +++ b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/sleep.ts @@ -72,7 +72,7 @@ export const sleepWithRawHttp = actor({ onSleep: (c) => { c.state.sleepCount += 1; }, - onFetch: async (c, request) => { + onRequest: async (c, request) => { c.state.requestCount += 1; const url = new URL(request.url); @@ -112,7 +112,7 @@ export const sleepWithRawWebSocket = actor({ onSleep: (c) => { c.state.sleepCount += 1; }, - onWebSocket: (c, websocket: UniversalWebSocket, opts) => { + onWebSocket: (c, websocket: UniversalWebSocket) => { c.state.connectionCount += 1; c.log.info({ msg: "websocket connected", diff --git a/rivetkit-typescript/packages/rivetkit/scripts/dump-openapi.ts b/rivetkit-typescript/packages/rivetkit/scripts/dump-openapi.ts index fa31da47e7..e60a09c436 100644 --- a/rivetkit-typescript/packages/rivetkit/scripts/dump-openapi.ts +++ b/rivetkit-typescript/packages/rivetkit/scripts/dump-openapi.ts @@ -32,6 +32,7 @@ function main() { getWithKey: unimplemented, getOrCreateWithKey: unimplemented, createActor: unimplemented, + listActors: unimplemented, sendRequest: unimplemented, openWebSocket: unimplemented, proxyRequest: unimplemented, diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/config.ts b/rivetkit-typescript/packages/rivetkit/src/actor/config.ts index 9281bce706..332ceba3cc 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/config.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/config.ts @@ -3,6 +3,11 @@ import type { UniversalWebSocket } from "@/common/websocket-interface"; import type { Conn } from "./conn/mod"; import type { ActionContext } from "./contexts/action"; import type { ActorContext } from "./contexts/actor"; +import type { CreateConnStateContext } from "./contexts/create-conn-state"; +import type { OnBeforeConnectContext } from "./contexts/on-before-connect"; +import type { OnConnectContext } from "./contexts/on-connect"; +import type { RequestContext } from "./contexts/request"; +import type { WebSocketContext } from "./contexts/websocket"; import type { AnyDatabaseProvider } from "./database"; export type InitContext = ActorContext< @@ -45,7 +50,7 @@ export const ActorConfigSchema = z onConnect: z.function().optional(), onDisconnect: z.function().optional(), onBeforeActionResponse: z.function().optional(), - onFetch: z.function().optional(), + onRequest: z.function().optional(), onWebSocket: z.function().optional(), actions: z.record(z.function()).default({}), state: z.any().optional(), @@ -111,15 +116,6 @@ export const ActorConfigSchema = z }, ); -export interface OnConnectOptions { - /** - * The request object associated with the connection. - * - * @experimental - */ - request?: Request; -} - // Creates state config // // This must have only one or the other or else TState will not be able to be inferred @@ -146,13 +142,12 @@ type CreateConnState< TConnState, TVars, TInput, - TDatabase, + TDatabase extends AnyDatabaseProvider, > = | { connState: TConnState } | { createConnState: ( - c: InitContext, - opts: OnConnectOptions, + c: CreateConnStateContext, params: TConnParams, ) => TConnState | Promise; } @@ -321,15 +316,7 @@ interface BaseActorConfig< * @throws Throw an error to reject the connection */ onBeforeConnect?: ( - c: ActorContext< - TState, - TConnParams, - TConnState, - TVars, - TInput, - TDatabase - >, - opts: OnConnectOptions, + c: OnBeforeConnectContext, params: TConnParams, ) => void | Promise; @@ -343,7 +330,7 @@ interface BaseActorConfig< * @returns Void or a Promise that resolves when connection handling is complete */ onConnect?: ( - c: ActorContext< + c: OnConnectContext< TState, TConnParams, TConnState, @@ -407,11 +394,13 @@ interface BaseActorConfig< * This handler receives raw HTTP requests made to `/actors/{actorName}/http/*` endpoints. * Use this hook to handle custom HTTP patterns, REST APIs, or other HTTP-based protocols. * + * @param c The request context with access to the connection * @param request The raw HTTP request object + * @param opts Additional options * @returns A Response object to send back, or void to continue with default routing */ - onFetch?: ( - c: ActorContext< + onRequest?: ( + c: RequestContext< TState, TConnParams, TConnState, @@ -420,7 +409,6 @@ interface BaseActorConfig< TDatabase >, request: Request, - opts: {}, ) => Response | Promise; /** @@ -429,11 +417,12 @@ interface BaseActorConfig< * This handler receives WebSocket connections made to `/actors/{actorName}/websocket/*` endpoints. * Use this hook to handle custom WebSocket protocols, binary streams, or other WebSocket-based communication. * + * @param c The WebSocket context with access to the connection * @param websocket The raw WebSocket connection - * @param request The original HTTP upgrade request + * @param opts Additional options including the original HTTP upgrade request */ onWebSocket?: ( - c: ActorContext< + c: WebSocketContext< TState, TConnParams, TConnState, @@ -442,7 +431,6 @@ interface BaseActorConfig< TDatabase >, websocket: UniversalWebSocket, - opts: { request: Request }, ) => void | Promise; actions: TActions; @@ -477,7 +465,7 @@ export type ActorConfig< | "onConnect" | "onDisconnect" | "onBeforeActionResponse" - | "onFetch" + | "onRequest" | "onWebSocket" | "state" | "createState" @@ -537,7 +525,7 @@ export type ActorConfigInput< | "onConnect" | "onDisconnect" | "onBeforeActionResponse" - | "onFetch" + | "onRequest" | "onWebSocket" | "state" | "createState" diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn/driver.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn/driver.ts index b53080e8c0..1c7da41096 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/conn/driver.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn/driver.ts @@ -12,14 +12,27 @@ export enum DriverReadyState { } export interface ConnDriver { + /** The type of driver. Used for debug purposes only. */ + type: string; + + /** + * Unique request ID provided by the underlying provider. If none is + * available for this conn driver, a random UUID is generated. + **/ requestId: string; + + /** ArrayBuffer version of requestId if relevant. */ requestIdBuf: ArrayBuffer | undefined; + + /** + * If the connection can be hibernated. If true, this will allow the actor to go to sleep while the connection is still active. + **/ hibernatable: boolean; sendMessage?( actor: AnyActorInstance, conn: AnyConn, - message: CachedSerializer, + message: CachedSerializer, ): void; /** diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/http.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/http.ts index a332f2bbea..518ed8aed9 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/http.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/http.ts @@ -4,6 +4,7 @@ export type ConnHttpState = Record; export function createHttpSocket(): ConnDriver { return { + type: "http", requestId: crypto.randomUUID(), requestIdBuf: undefined, hibernatable: false, diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/raw-request.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/raw-request.ts new file mode 100644 index 0000000000..44a17c5c5f --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/raw-request.ts @@ -0,0 +1,27 @@ +import type { ConnDriver } from "../driver"; +import { DriverReadyState } from "../driver"; + +/** + * Creates a raw HTTP connection driver. + * + * This driver is used for raw HTTP connections that don't use the RivetKit protocol. + * Unlike the standard HTTP driver, this provides connection lifecycle management + * for tracking the HTTP request through the actor's onRequest handler. + */ +export function createRawRequestSocket(): ConnDriver { + return { + type: "raw-request", + requestId: crypto.randomUUID(), + requestIdBuf: undefined, + hibernatable: false, + + disconnect: async () => { + // Noop + }, + + getConnectionReadyState: (): DriverReadyState | undefined => { + // HTTP connections are always considered open until the request completes + return DriverReadyState.OPEN; + }, + }; +} diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/raw-websocket.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/raw-websocket.ts index 969c575a7f..2f89cf6842 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/raw-websocket.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/raw-websocket.ts @@ -1,7 +1,8 @@ import type { AnyConn } from "@/actor/conn/mod"; import type { AnyActorInstance } from "@/actor/instance/mod"; import type { UniversalWebSocket } from "@/common/websocket-interface"; -import type { ConnDriver, DriverReadyState } from "../driver"; +import { loggerWithoutContext } from "../../log"; +import { type ConnDriver, DriverReadyState } from "../driver"; /** * Creates a raw WebSocket connection driver. @@ -15,10 +16,12 @@ export function createRawWebSocketSocket( requestId: string, requestIdBuf: ArrayBuffer | undefined, hibernatable: boolean, - websocket: UniversalWebSocket, closePromise: Promise, -): ConnDriver { - return { +): { driver: ConnDriver; setWebSocket(ws: UniversalWebSocket): void } { + let websocket: UniversalWebSocket | undefined; + + const driver: ConnDriver = { + type: "raw-websocket", requestId, requestIdBuf, hibernatable, @@ -31,6 +34,13 @@ export function createRawWebSocketSocket( _conn: AnyConn, reason?: string, ) => { + if (!websocket) { + loggerWithoutContext().warn( + "disconnecting raw ws without websocket", + ); + return; + } + // Close socket websocket.close(1000, reason); @@ -39,14 +49,21 @@ export function createRawWebSocketSocket( }, terminate: () => { - (websocket as any).terminate?.(); + (websocket as any)?.terminate?.(); }, getConnectionReadyState: ( _actor: AnyActorInstance, _conn: AnyConn, ): DriverReadyState | undefined => { - return websocket.readyState; + return websocket?.readyState ?? DriverReadyState.CONNECTING; + }, + }; + + return { + driver, + setWebSocket(ws) { + websocket = ws; }, }; } diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/websocket.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/websocket.ts index ea4e1cc977..b67959cf16 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/websocket.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/websocket.ts @@ -3,27 +3,38 @@ import type { AnyConn } from "@/actor/conn/mod"; import type { AnyActorInstance } from "@/actor/instance/mod"; import type { CachedSerializer, Encoding } from "@/actor/protocol/serde"; import type * as protocol from "@/schemas/client-protocol/mod"; +import { loggerWithoutContext } from "../../log"; import { type ConnDriver, DriverReadyState } from "../driver"; -export type ConnDriverWebSocketState = {}; +export type ConnDriverWebSocketState = Record; export function createWebSocketSocket( requestId: string, requestIdBuf: ArrayBuffer | undefined, hibernatable: boolean, encoding: Encoding, - websocket: WSContext, closePromise: Promise, -): ConnDriver { - return { +): { driver: ConnDriver; setWebSocket(ws: WSContext): void } { + // Wait for WS to open + let websocket: WSContext | undefined; + + const driver: ConnDriver = { + type: "websocket", requestId, requestIdBuf, hibernatable, sendMessage: ( actor: AnyActorInstance, conn: AnyConn, - message: CachedSerializer, + message: CachedSerializer, ) => { + if (!websocket) { + actor.rLog.warn({ + msg: "websocket not open", + connId: conn.id, + }); + return; + } if (websocket.readyState !== DriverReadyState.OPEN) { actor.rLog.warn({ msg: "attempting to send message to closed websocket, this is likely a bug in RivetKit", @@ -82,6 +93,13 @@ export function createWebSocketSocket( _conn: AnyConn, reason?: string, ) => { + if (!websocket) { + loggerWithoutContext().warn( + "disconnecting ws without websocket", + ); + return; + } + // Close socket websocket.close(1000, reason); @@ -97,7 +115,14 @@ export function createWebSocketSocket( _actor: AnyActorInstance, _conn: AnyConn, ): DriverReadyState | undefined => { - return websocket.readyState; + return websocket?.readyState ?? DriverReadyState.CONNECTING; + }, + }; + + return { + driver, + setWebSocket(ws) { + websocket = ws; }, }; } diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts index c4fdf59eee..d5854955c7 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts @@ -1,18 +1,18 @@ import * as cbor from "cbor-x"; -import onChange from "on-change"; -import { isCborSerializable } from "@/common/utils"; import type * as protocol from "@/schemas/client-protocol/mod"; import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned"; +import { + type ToClient as ToClientJson, + ToClientSchema, +} from "@/schemas/client-protocol-zod/mod"; import { arrayBuffersEqual, bufferToArrayBuffer } from "@/utils"; import type { AnyDatabaseProvider } from "../database"; -import * as errors from "../errors"; -import { - ACTOR_INSTANCE_PERSIST_SYMBOL, - type ActorInstance, -} from "../instance/mod"; +import { InternalError } from "../errors"; +import type { ActorInstance } from "../instance/mod"; import type { PersistedConn } from "../instance/persisted"; import { CachedSerializer } from "../protocol/serde"; import type { ConnDriver } from "./driver"; +import { StateManager } from "./state-manager"; export function generateConnRequestId(): string { return crypto.randomUUID(); @@ -22,8 +22,15 @@ export type ConnId = string; export type AnyConn = Conn; +export const CONN_CONNECTED_SYMBOL = Symbol("connected"); export const CONN_PERSIST_SYMBOL = Symbol("persist"); export const CONN_DRIVER_SYMBOL = Symbol("driver"); +export const CONN_ACTOR_SYMBOL = Symbol("actor"); +export const CONN_STATE_ENABLED_SYMBOL = Symbol("stateEnabled"); +export const CONN_PERSIST_RAW_SYMBOL = Symbol("persistRaw"); +export const CONN_HAS_CHANGES_SYMBOL = Symbol("hasChanges"); +export const CONN_MARK_SAVED_SYMBOL = Symbol("markSaved"); +export const CONN_SEND_MESSAGE_SYMBOL = Symbol("sendMessage"); /** * Represents a client connection to a actor. @@ -38,34 +45,40 @@ export class Conn { // TODO: Remove this cyclical reference #actor: ActorInstance; - /** - * The proxied state that notifies of changes automatically. - * - * Any data that should be stored indefinitely should be held within this - * object. - * - * This will only be persisted if using hibernatable WebSockets. If not, - * this is just used to hole state. - */ - [CONN_PERSIST_SYMBOL]!: PersistedConn; - - /** Raw persist object without the proxy wrapper */ - #persistRaw: PersistedConn; - - /** Track if this connection's state has changed */ - #changed = false; + // MARK: - Managers + #stateManager!: StateManager; /** * If undefined, then nothing is connected to this. */ [CONN_DRIVER_SYMBOL]?: ConnDriver; - public get params(): CP { - return this[CONN_PERSIST_SYMBOL].params; + // MARK: - Public Getters + + get [CONN_ACTOR_SYMBOL](): ActorInstance { + return this.#actor; } - public get stateEnabled() { - return this.#actor.connStateEnabled; + /** Connections exist before being connected to an actor. If true, this connection has been connected. */ + [CONN_CONNECTED_SYMBOL] = false; + + #assertConnected() { + if (!this[CONN_CONNECTED_SYMBOL]) + throw new InternalError( + "Connection not connected yet. This happens when trying to use the connection in onBeforeConnect or createConnState.", + ); + } + + get [CONN_PERSIST_SYMBOL](): PersistedConn { + return this.#stateManager.persist; + } + + get params(): CP { + return this.#stateManager.params; + } + + get [CONN_STATE_ENABLED_SYMBOL](): boolean { + return this.#stateManager.stateEnabled; } /** @@ -73,11 +86,8 @@ export class Conn { * * Throws an error if the state is not enabled. */ - public get state(): CS { - this.#validateStateEnabled(); - if (!this[CONN_PERSIST_SYMBOL].state) - throw new Error("state should exists"); - return this[CONN_PERSIST_SYMBOL].state; + get state(): CS { + return this.#stateManager.state; } /** @@ -85,16 +95,15 @@ export class Conn { * * Throws an error if the state is not enabled. */ - public set state(value: CS) { - this.#validateStateEnabled(); - this[CONN_PERSIST_SYMBOL].state = value; + set state(value: CS) { + this.#stateManager.state = value; } /** * Unique identifier for the connection. */ - public get id(): ConnId { - return this[CONN_PERSIST_SYMBOL].connId; + get id(): ConnId { + return this.#stateManager.persist.connId; } /** @@ -102,17 +111,17 @@ export class Conn { * * If the underlying connection can hibernate. */ - public get isHibernatable(): boolean { - if (!this[CONN_PERSIST_SYMBOL].hibernatableRequestId) { + get isHibernatable(): boolean { + const hibernatableRequestId = + this.#stateManager.persist.hibernatableRequestId; + if (!hibernatableRequestId) { return false; } return ( - (this.#actor as any)[ - ACTOR_INSTANCE_PERSIST_SYMBOL - ].hibernatableConns.findIndex((conn: any) => + this.#actor.persist.hibernatableConns.findIndex((conn: any) => arrayBuffersEqual( conn.hibernatableRequestId, - this[CONN_PERSIST_SYMBOL].hibernatableRequestId!, + hibernatableRequestId, ), ) > -1 ); @@ -121,8 +130,8 @@ export class Conn { /** * Timestamp of the last time the connection was seen, i.e. the last time the connection was active and checked for liveness. */ - public get lastSeen(): number { - return this[CONN_PERSIST_SYMBOL].lastSeen; + get lastSeen(): number { + return this.#stateManager.persist.lastSeen; } /** @@ -132,94 +141,37 @@ export class Conn { * * @protected */ - public constructor( + constructor( actor: ActorInstance, persist: PersistedConn, ) { this.#actor = actor; - this.#persistRaw = persist; - this.#setupPersistProxy(persist); - } - - /** - * Sets up the proxy for connection persistence with change tracking - */ - #setupPersistProxy(persist: PersistedConn) { - // If this can't be proxied, return raw value - if (persist === null || typeof persist !== "object") { - this[CONN_PERSIST_SYMBOL] = persist; - return; - } - - // Listen for changes to the object - this[CONN_PERSIST_SYMBOL] = onChange( - persist, - ( - path: string, - value: any, - _previousValue: any, - _applyData: any, - ) => { - // Validate CBOR serializability for state changes - if (path.startsWith("state")) { - let invalidPath = ""; - if ( - !isCborSerializable( - value, - (invalidPathPart: string) => { - invalidPath = invalidPathPart; - }, - "", - ) - ) { - throw new errors.InvalidStateType({ - path: path + (invalidPath ? `.${invalidPath}` : ""), - }); - } - } - - this.#changed = true; - this.#actor.rLog.debug({ - msg: "conn onChange triggered", - connId: this.id, - path, - }); - - // Notify actor that this connection has changed - this.#actor.markConnChanged(this); - }, - { ignoreDetached: true }, - ); + this.#stateManager = new StateManager(this); + this.#stateManager.initPersistProxy(persist); } /** * Returns whether this connection has unsaved changes */ - get hasChanges(): boolean { - return this.#changed; + [CONN_HAS_CHANGES_SYMBOL](): boolean { + return this.#stateManager.hasChanges(); } /** * Marks changes as saved */ - markSaved() { - this.#changed = false; + [CONN_MARK_SAVED_SYMBOL]() { + this.#stateManager.markSaved(); } /** * Gets the raw persist data for serialization */ - get persistRaw(): PersistedConn { - return this.#persistRaw; - } - - #validateStateEnabled() { - if (!this.stateEnabled) { - throw new errors.ConnStateNotEnabled(); - } + get [CONN_PERSIST_RAW_SYMBOL](): PersistedConn { + return this.#stateManager.persistRaw; } - public sendMessage(message: CachedSerializer) { + [CONN_SEND_MESSAGE_SYMBOL](message: CachedSerializer) { if (this[CONN_DRIVER_SYMBOL]) { const driver = this[CONN_DRIVER_SYMBOL]; if (driver.sendMessage) { @@ -245,25 +197,41 @@ export class Conn { * @param args - The arguments for the event. * @see {@link https://rivet.dev/docs/events|Events Documentation} */ - public send(eventName: string, ...args: unknown[]) { + send(eventName: string, ...args: unknown[]) { + this.#assertConnected(); + this.#actor.inspector.emitter.emit("eventFired", { type: "event", eventName, args, connId: this.id, }); - this.sendMessage( - new CachedSerializer( - { + const eventData = { name: eventName, args }; + this[CONN_SEND_MESSAGE_SYMBOL]( + new CachedSerializer( + eventData, + TO_CLIENT_VERSIONED, + ToClientSchema, + // JSON: args is the raw value (array of arguments) + (value): ToClientJson => ({ body: { - tag: "Event", + tag: "Event" as const, val: { - name: eventName, - args: bufferToArrayBuffer(cbor.encode(args)), + name: value.name, + args: value.args, }, }, - }, - TO_CLIENT_VERSIONED, + }), + // BARE/CBOR: args needs to be CBOR-encoded to ArrayBuffer + (value): protocol.ToClient => ({ + body: { + tag: "Event" as const, + val: { + name: value.name, + args: bufferToArrayBuffer(cbor.encode(value.args)), + }, + }, + }), ), ); } @@ -273,7 +241,7 @@ export class Conn { * * @param reason - The reason for disconnection. */ - public async disconnect(reason?: string) { + async disconnect(reason?: string) { if (this[CONN_DRIVER_SYMBOL]) { const driver = this[CONN_DRIVER_SYMBOL]; if (driver.disconnect) { @@ -285,7 +253,7 @@ export class Conn { }); } - this.#actor.connDisconnected(this, true); + this.#actor.connectionManager.connDisconnected(this); } else { this.#actor.rLog.warn({ msg: "missing connection driver state for disconnect", diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn/state-manager.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn/state-manager.ts new file mode 100644 index 0000000000..a79895dea4 --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn/state-manager.ts @@ -0,0 +1,141 @@ +import onChange from "on-change"; +import { isCborSerializable } from "@/common/utils"; +import * as errors from "../errors"; +import type { PersistedConn } from "../instance/persisted"; +import { CONN_ACTOR_SYMBOL, CONN_STATE_ENABLED_SYMBOL, type Conn } from "./mod"; + +/** + * Manages connection state persistence, proxying, and change tracking. + * Handles automatic state change detection for connection-specific state. + */ +export class StateManager { + #conn: Conn; + + // State tracking + #persist!: PersistedConn; + #persistRaw!: PersistedConn; + #changed = false; + + constructor(conn: Conn) { + this.#conn = conn; + } + + // MARK: - Public API + + get persist(): PersistedConn { + return this.#persist; + } + + get persistRaw(): PersistedConn { + return this.#persistRaw; + } + + get changed(): boolean { + return this.#changed; + } + + get stateEnabled(): boolean { + return this.#conn[CONN_ACTOR_SYMBOL].connStateEnabled; + } + + get state(): CS { + this.#validateStateEnabled(); + if (!this.#persist.state) throw new Error("state should exists"); + return this.#persist.state; + } + + set state(value: CS) { + this.#validateStateEnabled(); + this.#persist.state = value; + } + + get params(): CP { + return this.#persist.params; + } + + // MARK: - Initialization + + /** + * Creates proxy for persist object that handles automatic state change detection. + */ + initPersistProxy(target: PersistedConn) { + // Set raw persist object + this.#persistRaw = target; + + // If this can't be proxied, return raw value + if (target === null || typeof target !== "object") { + this.#persist = target; + return; + } + + // Listen for changes to the object + this.#persist = onChange( + target, + ( + path: string, + value: any, + _previousValue: any, + _applyData: any, + ) => { + this.#handleChange(path, value); + }, + { ignoreDetached: true }, + ); + } + + // MARK: - Change Management + + /** + * Returns whether this connection has unsaved changes + */ + hasChanges(): boolean { + return this.#changed; + } + + /** + * Marks changes as saved + */ + markSaved() { + this.#changed = false; + } + + // MARK: - Private Helpers + + #validateStateEnabled() { + if (!this.stateEnabled) { + throw new errors.ConnStateNotEnabled(); + } + } + + #handleChange(path: string, value: any) { + // Validate CBOR serializability for state changes + if (path.startsWith("state")) { + let invalidPath = ""; + if ( + !isCborSerializable( + value, + (invalidPathPart: string) => { + invalidPath = invalidPathPart; + }, + "", + ) + ) { + throw new errors.InvalidStateType({ + path: path + (invalidPath ? `.${invalidPath}` : ""), + }); + } + } + + this.#changed = true; + this.#conn[CONN_ACTOR_SYMBOL].rLog.debug({ + msg: "conn onChange triggered", + connId: this.#conn.id, + path, + }); + + // Notify actor that this connection has changed + this.#conn[CONN_ACTOR_SYMBOL].connectionManager.markConnChanged( + this.#conn, + ); + } +} diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/action.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/action.ts index a8953d883e..9e99862a69 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/action.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/action.ts @@ -1,17 +1,10 @@ -import type { ActorKey } from "@/actor/mod"; -import type { Client } from "@/client/client"; -import type { Logger } from "@/common/log"; -import type { Registry } from "@/registry/mod"; -import type { Conn, ConnId } from "../conn/mod"; -import type { AnyDatabaseProvider, InferDatabaseClient } from "../database"; -import type { SaveStateOptions } from "../instance/state-manager"; -import type { Schedule } from "../schedule"; -import type { ActorContext } from "./actor"; +import type { Conn } from "../conn/mod"; +import type { AnyDatabaseProvider } from "../database"; +import type { ActorInstance } from "../instance/mod"; +import { ConnContext } from "./conn"; /** * Context for a remote procedure call. - * - * @typeParam A Actor this action belongs to */ export class ActionContext< TState, @@ -20,159 +13,11 @@ export class ActionContext< TVars, TInput, TDatabase extends AnyDatabaseProvider, -> { - #actorContext: ActorContext< - TState, - TConnParams, - TConnState, - TVars, - TInput, - TDatabase - >; - - /** - * Should not be called directly. - * - * @param actorContext - The actor context - * @param conn - The connection associated with the action - */ - constructor( - actorContext: ActorContext< - TState, - TConnParams, - TConnState, - TVars, - TInput, - TDatabase - >, - public readonly conn: Conn< - TState, - TConnParams, - TConnState, - TVars, - TInput, - TDatabase - >, - ) { - this.#actorContext = actorContext; - } - - /** - * Get the actor state - */ - get state(): TState { - return this.#actorContext.state; - } - - /** - * Get the actor variables - */ - get vars(): TVars { - return this.#actorContext.vars; - } - - /** - * Broadcasts an event to all connected clients. - */ - broadcast(name: string, ...args: any[]): void { - this.#actorContext.broadcast(name, ...args); - } - - /** - * Gets the logger instance. - */ - get log(): Logger { - return this.#actorContext.log; - } - - /** - * Gets actor ID. - */ - get actorId(): string { - return this.#actorContext.actorId; - } - - /** - * Gets the actor name. - */ - get name(): string { - return this.#actorContext.name; - } - - /** - * Gets the actor key. - */ - get key(): ActorKey { - return this.#actorContext.key; - } - - /** - * Gets the region. - */ - get region(): string { - return this.#actorContext.region; - } - - /** - * Gets the scheduler. - */ - get schedule(): Schedule { - return this.#actorContext.schedule; - } - - /** - * Gets the map of connections. - */ - get conns(): Map< - ConnId, - Conn - > { - return this.#actorContext.conns; - } - - /** - * Returns the client for the given registry. - */ - client>(): Client { - return this.#actorContext.client(); - } - - /** - * @experimental - */ - get db(): InferDatabaseClient { - return this.#actorContext.db; - } - - /** - * Forces the state to get saved. - */ - async saveState(opts: SaveStateOptions): Promise { - return this.#actorContext.saveState(opts); - } - - /** - * Prevents the actor from sleeping until promise is complete. - */ - waitUntil(promise: Promise): void { - this.#actorContext.waitUntil(promise); - } - - /** - * AbortSignal that fires when the actor is stopping. - */ - get abortSignal(): AbortSignal { - return this.#actorContext.abortSignal; - } - - /** - * Forces the actor to sleep. - * - * Not supported on all drivers. - * - * @experimental - */ - sleep() { - this.#actorContext.sleep(); - } -} +> extends ConnContext< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase +> {} diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/actor.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/actor.ts index 0096580854..dde24c45bd 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/actor.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/actor.ts @@ -60,7 +60,7 @@ export class ActorContext< * @param args - The arguments to send with the event. */ broadcast>(name: string, ...args: Args): void { - this.#actor.broadcast(name, ...args); + this.#actor.eventManager.broadcast(name, ...args); return; } diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/conn-init.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/conn-init.ts new file mode 100644 index 0000000000..801bd86046 --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/conn-init.ts @@ -0,0 +1,31 @@ +import type { AnyDatabaseProvider } from "../database"; +import type { ActorInstance } from "../instance/mod"; +import { ActorContext } from "./actor"; + +/** + * Base context for connection initialization handlers. + * Extends ActorContext with request-specific functionality for connection lifecycle events. + */ +export abstract class ConnInitContext< + TState, + TVars, + TInput, + TDatabase extends AnyDatabaseProvider, +> extends ActorContext { + /** + * The incoming request that initiated the connection. + * May be undefined for connections initiated without a direct HTTP request. + */ + public readonly request: Request | undefined; + + /** + * @internal + */ + constructor( + actor: ActorInstance, + request: Request | undefined, + ) { + super(actor); + this.request = request; + } +} diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/conn.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/conn.ts new file mode 100644 index 0000000000..725d40b418 --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/conn.ts @@ -0,0 +1,48 @@ +import type { Conn } from "../conn/mod"; +import type { AnyDatabaseProvider } from "../database"; +import type { ActorInstance } from "../instance/mod"; +import { ActorContext } from "./actor"; + +/** + * Base context for connection-based handlers. + * Extends ActorContext with connection-specific functionality. + */ +export abstract class ConnContext< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase extends AnyDatabaseProvider, +> extends ActorContext< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase +> { + /** + * @internal + */ + constructor( + actor: ActorInstance< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase + >, + public readonly conn: Conn< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase + >, + ) { + super(actor); + } +} diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/create-conn-state.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/create-conn-state.ts new file mode 100644 index 0000000000..78e9c78e6c --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/create-conn-state.ts @@ -0,0 +1,13 @@ +import type { AnyDatabaseProvider } from "../database"; +import { ConnInitContext } from "./conn-init"; + +/** + * Context for the createConnState lifecycle hook. + * Called to initialize connection-specific state when a connection is created. + */ +export class CreateConnStateContext< + TState, + TVars, + TInput, + TDatabase extends AnyDatabaseProvider, +> extends ConnInitContext {} diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/on-before-connect.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/on-before-connect.ts new file mode 100644 index 0000000000..759c210052 --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/on-before-connect.ts @@ -0,0 +1,13 @@ +import type { AnyDatabaseProvider } from "../database"; +import { ConnInitContext } from "./conn-init"; + +/** + * Context for the onBeforeConnect lifecycle hook. + * Called before a connection is established, allowing for validation and early rejection. + */ +export class OnBeforeConnectContext< + TState, + TVars, + TInput, + TDatabase extends AnyDatabaseProvider, +> extends ConnInitContext {} diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/on-connect.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/on-connect.ts new file mode 100644 index 0000000000..14493f3c08 --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/on-connect.ts @@ -0,0 +1,22 @@ +import type { AnyDatabaseProvider } from "../database"; +import { ConnContext } from "./conn"; + +/** + * Context for the onConnect lifecycle hook. + * Called when a connection is successfully established. + */ +export class OnConnectContext< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase extends AnyDatabaseProvider, +> extends ConnContext< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase +> {} diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/request.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/request.ts new file mode 100644 index 0000000000..1615687f61 --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/request.ts @@ -0,0 +1,48 @@ +import type { Conn } from "../conn/mod"; +import type { AnyDatabaseProvider } from "../database"; +import type { ActorInstance } from "../instance/mod"; +import { ConnContext } from "./conn"; + +/** + * Context for raw HTTP request handlers (onRequest). + */ +export class RequestContext< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase extends AnyDatabaseProvider, +> extends ConnContext< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase +> { + /** + * The incoming HTTP request. + * May be undefined for request contexts initiated without a direct HTTP request. + */ + public readonly request: Request | undefined; + + /** + * @internal + */ + constructor( + actor: ActorInstance< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase + >, + conn: Conn, + request?: Request, + ) { + super(actor, conn); + this.request = request; + } +} diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/websocket.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/websocket.ts new file mode 100644 index 0000000000..f368330971 --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/websocket.ts @@ -0,0 +1,48 @@ +import type { Conn } from "../conn/mod"; +import type { AnyDatabaseProvider } from "../database"; +import type { ActorInstance } from "../instance/mod"; +import { ConnContext } from "./conn"; + +/** + * Context for raw WebSocket handlers (onWebSocket). + */ +export class WebSocketContext< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase extends AnyDatabaseProvider, +> extends ConnContext< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase +> { + /** + * The incoming HTTP request that initiated the WebSocket upgrade. + * May be undefined for WebSocket connections initiated without a direct HTTP request. + */ + public readonly request: Request | undefined; + + /** + * @internal + */ + constructor( + actor: ActorInstance< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase + >, + conn: Conn, + request?: Request, + ) { + super(actor, conn); + this.request = request; + } +} diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/errors.ts b/rivetkit-typescript/packages/rivetkit/src/actor/errors.ts index 07a517dbd0..8b7b9bfd05 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/errors.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/errors.ts @@ -317,12 +317,12 @@ export class DatabaseNotEnabled extends ActorError { } } -export class FetchHandlerNotDefined extends ActorError { +export class RequestHandlerNotDfeined extends ActorError { constructor() { super( "handler", - "fetch_not_defined", - "Raw HTTP handler not defined. Actor must implement `onFetch` to handle raw HTTP requests. (https://www.rivet.dev/docs/actors/fetch-and-websocket-handler/)", + "request_not_defined", + "Raw request handler not defined. Actor must implement `onRequest` to handle raw HTTP requests. (https://www.rivet.dev/docs/actors/fetch-and-websocket-handler/)", { public: true }, ); this.statusCode = 404; @@ -341,12 +341,12 @@ export class WebSocketHandlerNotDefined extends ActorError { } } -export class InvalidFetchResponse extends ActorError { +export class InvalidRequestHandlerResponse extends ActorError { constructor() { super( "handler", - "invalid_fetch_response", - "Actor's onFetch handler must return a Response object. Returning void/undefined is not allowed. (https://www.rivet.dev/docs/actors/fetch-and-websocket-handler/)", + "invalid_request_handler_response", + "Actor's onRequest handler must return a Response object. Returning void/undefined is not allowed. (https://www.rivet.dev/docs/actors/fetch-and-websocket-handler/)", { public: true }, ); this.statusCode = 500; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/connection-manager.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/connection-manager.ts index 23d6481ed0..8ad2847c49 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance/connection-manager.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/connection-manager.ts @@ -1,18 +1,27 @@ import * as cbor from "cbor-x"; -import { arrayBuffersEqual, idToStr, stringifyError } from "@/utils"; -import type { OnConnectOptions } from "../config"; +import invariant from "invariant"; +import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned"; +import { ToClientSchema } from "@/schemas/client-protocol-zod/mod"; +import { arrayBuffersEqual, stringifyError } from "@/utils"; import type { ConnDriver } from "../conn/driver"; import { + CONN_CONNECTED_SYMBOL, CONN_DRIVER_SYMBOL, + CONN_MARK_SAVED_SYMBOL, + CONN_PERSIST_RAW_SYMBOL, CONN_PERSIST_SYMBOL, + CONN_SEND_MESSAGE_SYMBOL, Conn, type ConnId, } from "../conn/mod"; +import { CreateConnStateContext } from "../contexts/create-conn-state"; +import { OnBeforeConnectContext } from "../contexts/on-before-connect"; +import { OnConnectContext } from "../contexts/on-connect"; import type { AnyDatabaseProvider } from "../database"; -import type { ActorDriver } from "../driver"; +import { CachedSerializer } from "../protocol/serde"; import { deadline } from "../utils"; import { makeConnKey } from "./kv"; -import { ACTOR_INSTANCE_PERSIST_SYMBOL, type ActorInstance } from "./mod"; +import type { ActorInstance } from "./mod"; import type { PersistedConn } from "./persisted"; /** @@ -63,15 +72,16 @@ export class ConnectionManager< } // MARK: - Connection Lifecycle - /** - * Creates a new connection or reconnects an existing hibernatable connection. + * Handles pre-connection logic (i.e. auth & create state) before actually connecting the connection. */ - async createConn( + async prepareConn( driver: ConnDriver, params: CP, - request?: Request, + request: Request | undefined, ): Promise> { + this.#actor.assertReady(); + // Check for hibernatable websocket reconnection if (driver.requestIdBuf && driver.hibernatable) { const existingConn = this.#findHibernatableConn( @@ -84,76 +94,121 @@ export class ConnectionManager< } // Create new connection - return await this.#createNewConn(driver, params, request); + const persist = this.#actor.persist; + if (this.#actor.config.onBeforeConnect) { + const ctx = new OnBeforeConnectContext(this.#actor, request); + await this.#actor.config.onBeforeConnect(ctx, params); + } + + // Create connection state if enabled + let connState: CS | undefined; + if (this.#actor.connStateEnabled) { + connState = await this.#createConnState(params, request); + } + + // Create connection persist data + const connPersist: PersistedConn = { + connId: crypto.randomUUID(), + params: params, + state: connState as CS, + lastSeen: Date.now(), + subscriptions: [], + }; + + // Check if hibernatable + if (driver.requestIdBuf) { + const isHibernatable = this.#isHibernatableRequest( + driver.requestIdBuf, + ); + if (isHibernatable) { + connPersist.hibernatableRequestId = driver.requestIdBuf; + } + } + + // Create connection instance + const conn = new Conn(this.#actor, connPersist); + conn[CONN_DRIVER_SYMBOL] = driver; + + return conn; } /** - * Handle connection disconnection. - * Clean disconnects remove the connection immediately. - * Unclean disconnects keep the connection for potential reconnection. + * Adds a connection form prepareConn to the actor and calls onConnect. + * + * This method is intentionally not async since it needs to be called in + * `onOpen` for WebSockets. If this is async, the order of open events will + * be messed up and cause race conditions that can drop WebSocket messages. + * So all async work in prepareConn. */ - async connDisconnected( - conn: Conn, - wasClean: boolean, - actorDriver: ActorDriver, - eventManager: any, // EventManager type - ) { - if (wasClean) { - // Clean disconnect - remove immediately - await this.removeConn(conn, actorDriver, eventManager); - } else { - // Unclean disconnect - keep for reconnection - this.#handleUncleanDisconnect(conn); - } + connectConn(conn: Conn) { + invariant(!this.#connections.has(conn.id), "conn already connected"); + + this.#connections.set(conn.id, conn); + + this.#changedConnections.add(conn.id); + + this.#callOnConnect(conn); + + this.#actor.inspector.emitter.emit("connectionUpdated"); + + this.#actor.resetSleepTimer(); + + conn[CONN_CONNECTED_SYMBOL] = true; + + // TODO: Only do this for action messages + // Send init message + const initData = { actorId: this.#actor.id, connectionId: conn.id }; + conn[CONN_SEND_MESSAGE_SYMBOL]( + new CachedSerializer( + initData, + TO_CLIENT_VERSIONED, + ToClientSchema, + // JSON: identity conversion (no nested data to encode) + (value) => ({ + body: { + tag: "Init" as const, + val: value, + }, + }), + // BARE/CBOR: identity conversion (no nested data to encode) + (value) => ({ + body: { + tag: "Init" as const, + val: value, + }, + }), + ), + ); } /** - * Removes a connection and cleans up its resources. + * Handle connection disconnection. + * + * This is called by `Conn.disconnect`. This should not call `Conn.disconnect.` */ - async removeConn( - conn: Conn, - actorDriver: ActorDriver, - eventManager: any, // EventManager type - ) { - // Remove from KV storage - const key = makeConnKey(conn.id); - try { - await actorDriver.kvBatchDelete(this.#actor.id, [key]); - this.#actor.rLog.debug({ - msg: "removed connection from KV", - connId: conn.id, - }); - } catch (err) { - this.#actor.rLog.error({ - msg: "kvBatchDelete failed for conn", - err: stringifyError(err), - }); - } - + async connDisconnected(conn: Conn) { // Remove from tracking this.#connections.delete(conn.id); this.#changedConnections.delete(conn.id); this.#actor.rLog.debug({ msg: "removed conn", connId: conn.id }); - // Clean up subscriptions via EventManager - if (eventManager) { - for (const eventName of [...conn.subscriptions.values()]) { - eventManager.removeSubscription(eventName, conn, true); - } + for (const eventName of [...conn.subscriptions.values()]) { + this.#actor.eventManager.removeSubscription(eventName, conn, true); } - // Emit events and call lifecycle hooks + this.#actor.resetSleepTimer(); + this.#actor.inspector.emitter.emit("connectionUpdated"); - const config = (this.#actor as any).config; - if (config?.onDisconnect) { + // Trigger disconnect + if (this.#actor.config.onDisconnect) { try { - const result = config.onDisconnect( + const result = this.#actor.config.onDisconnect( this.#actor.actorContext, conn, ); if (result instanceof Promise) { - result.catch((error: any) => { + result.catch((error) => { this.#actor.rLog.error({ msg: "error in `onDisconnect`", error: stringifyError(error), @@ -167,6 +222,34 @@ export class ConnectionManager< }); } } + + // Remove from KV storage + const key = makeConnKey(conn.id); + try { + await this.#actor.driver.kvBatchDelete(this.#actor.id, [key]); + this.#actor.rLog.debug({ + msg: "removed connection from KV", + connId: conn.id, + }); + } catch (err) { + this.#actor.rLog.error({ + msg: "kvBatchDelete failed for conn", + err: stringifyError(err), + }); + } + } + + /** + * Utilify funtion for call sites that don't need a separate prepare and connect phase. + */ + async prepareAndConnectConn( + driver: ConnDriver, + params: CP, + request: Request | undefined, + ): Promise> { + const conn = await this.prepareConn(driver, params, request); + this.connectConn(conn); + return conn; } // MARK: - Persistence @@ -174,10 +257,7 @@ export class ConnectionManager< /** * Restores connections from persisted data during actor initialization. */ - restoreConnections( - connections: PersistedConn[], - eventManager: any, // EventManager type - ) { + restoreConnections(connections: PersistedConn[]) { for (const connPersist of connections) { // Create connection instance const conn = new Conn( @@ -188,7 +268,11 @@ export class ConnectionManager< // Restore subscriptions for (const sub of connPersist.subscriptions) { - eventManager.addSubscription(sub.eventName, conn, true); + this.#actor.eventManager.addSubscription( + sub.eventName, + conn, + true, + ); } } } @@ -202,9 +286,9 @@ export class ConnectionManager< for (const connId of this.#changedConnections) { const conn = this.#connections.get(connId); if (conn) { - const connData = cbor.encode(conn.persistRaw); + const connData = cbor.encode(conn[CONN_PERSIST_RAW_SYMBOL]); entries.push([makeConnKey(connId), connData]); - conn.markSaved(); + conn[CONN_MARK_SAVED_SYMBOL](); } } @@ -261,96 +345,25 @@ export class ConnectionManager< } } - async #createNewConn( - driver: ConnDriver, - params: CP, - request: Request | undefined, - ): Promise> { - const config = this.#actor.config; - const persist = (this.#actor as any)[ACTOR_INSTANCE_PERSIST_SYMBOL]; - // Prepare connection state - let connState: CS | undefined; - - const onBeforeConnectOpts = { - request, - } satisfies OnConnectOptions; - - // Call onBeforeConnect hook - if (config.onBeforeConnect) { - await config.onBeforeConnect( - this.#actor.actorContext, - onBeforeConnectOpts, - params, - ); - } - - // Create connection state if enabled - if ((this.#actor as any).connStateEnabled) { - connState = await this.#createConnState( - config, - onBeforeConnectOpts, - params, - ); - } - - // Create connection persist data - const connPersist: PersistedConn = { - connId: crypto.randomUUID(), - params: params, - state: connState as CS, - lastSeen: Date.now(), - subscriptions: [], - }; - - // Check if hibernatable - if (driver.requestIdBuf) { - const isHibernatable = this.#isHibernatableRequest( - driver.requestIdBuf, - persist, - ); - if (isHibernatable) { - connPersist.hibernatableRequestId = driver.requestIdBuf; - } - } - - // Create connection instance - const conn = new Conn(this.#actor, connPersist); - conn[CONN_DRIVER_SYMBOL] = driver; - this.#connections.set(conn.id, conn); - - // Mark as changed for persistence - this.#changedConnections.add(conn.id); - - // Call onConnect lifecycle hook - if (config.onConnect) { - this.#callOnConnect(config, conn); - } - - this.#actor.inspector.emitter.emit("connectionUpdated"); - - return conn; - } - async #createConnState( - config: any, - opts: OnConnectOptions, params: CP, + request: Request | undefined, ): Promise { - if ("createConnState" in config) { - const dataOrPromise = config.createConnState( - this.#actor.actorContext, - opts, + if ("createConnState" in this.#actor.config) { + const ctx = new CreateConnStateContext(this.#actor, request); + const dataOrPromise = this.#actor.config.createConnState( + ctx, params, ); if (dataOrPromise instanceof Promise) { return await deadline( dataOrPromise, - config.options.createConnStateTimeout, + this.#actor.config.options.createConnStateTimeout, ); } return dataOrPromise; - } else if ("connState" in config) { - return structuredClone(config.connState); + } else if ("connState" in this.#actor.config) { + return structuredClone(this.#actor.config.connState); } throw new Error( @@ -358,46 +371,38 @@ export class ConnectionManager< ); } - #isHibernatableRequest(requestIdBuf: ArrayBuffer, persist: any): boolean { + #isHibernatableRequest(requestIdBuf: ArrayBuffer): boolean { return ( - persist.hibernatableConns.findIndex((conn: any) => + this.#actor.persist.hibernatableConns.findIndex((conn) => arrayBuffersEqual(conn.hibernatableRequestId, requestIdBuf), ) !== -1 ); } - #callOnConnect(config: any, conn: Conn) { - try { - const result = config.onConnect(this.#actor.actorContext, conn); - if (result instanceof Promise) { - deadline(result, config.options.onConnectTimeout).catch( - (error: any) => { + #callOnConnect(conn: Conn) { + if (this.#actor.config.onConnect) { + try { + const ctx = new OnConnectContext(this.#actor, conn); + const result = this.#actor.config.onConnect(ctx, conn); + if (result instanceof Promise) { + deadline( + result, + this.#actor.config.options.onConnectTimeout, + ).catch((error) => { this.#actor.rLog.error({ msg: "error in `onConnect`, closing socket", error, }); conn?.disconnect("`onConnect` failed"); - }, - ); + }); + } + } catch (error) { + this.#actor.rLog.error({ + msg: "error in `onConnect`", + error: stringifyError(error), + }); + conn?.disconnect("`onConnect` failed"); } - } catch (error) { - this.#actor.rLog.error({ - msg: "error in `onConnect`", - error: stringifyError(error), - }); - conn?.disconnect("`onConnect` failed"); } } - - #handleUncleanDisconnect(conn: Conn) { - if (!conn[CONN_DRIVER_SYMBOL]) { - this.#actor.rLog.warn("called conn disconnected without driver"); - } - - // Update last seen for cleanup tracking - conn[CONN_PERSIST_SYMBOL].lastSeen = Date.now(); - - // Remove socket - conn[CONN_DRIVER_SYMBOL] = undefined; - } } diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts index 944617e002..20c252824b 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts @@ -1,8 +1,16 @@ import * as cbor from "cbor-x"; import type * as protocol from "@/schemas/client-protocol/mod"; import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned"; +import { + type ToClient as ToClientJson, + ToClientSchema, +} from "@/schemas/client-protocol-zod/mod"; import { bufferToArrayBuffer } from "@/utils"; -import { CONN_PERSIST_SYMBOL, type Conn } from "../conn/mod"; +import { + CONN_PERSIST_SYMBOL, + CONN_SEND_MESSAGE_SYMBOL, + type Conn, +} from "../conn/mod"; import type { AnyDatabaseProvider } from "../database"; import { CachedSerializer } from "../protocol/serde"; import type { ActorInstance } from "./mod"; @@ -157,6 +165,8 @@ export class EventManager { * @param args - The arguments to send with the event */ broadcast>(name: string, ...args: Args) { + this.#actor.assertReady(); + // Emit to inspector this.#actor.inspector.emitter.emit("eventFired", { type: "broadcast", @@ -175,24 +185,38 @@ export class EventManager { } // Create serialized message - const toClientSerializer = new CachedSerializer( - { + const eventData = { name, args }; + const toClientSerializer = new CachedSerializer( + eventData, + TO_CLIENT_VERSIONED, + ToClientSchema, + // JSON: args is the raw value (array of arguments) + (value): ToClientJson => ({ body: { - tag: "Event", + tag: "Event" as const, val: { - name, - args: bufferToArrayBuffer(cbor.encode(args)), + name: value.name, + args: value.args, }, }, - }, - TO_CLIENT_VERSIONED, + }), + // BARE/CBOR: args needs to be CBOR-encoded to ArrayBuffer + (value): protocol.ToClient => ({ + body: { + tag: "Event" as const, + val: { + name: value.name, + args: bufferToArrayBuffer(cbor.encode(value.args)), + }, + }, + }), ); // Send to all subscribers let sentCount = 0; for (const connection of subscribers) { try { - connection.sendMessage(toClientSerializer); + connection[CONN_SEND_MESSAGE_SYMBOL](toClientSerializer); sentCount++; } catch (error) { this.#actor.rLog.error({ diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts index d79a4aa93c..76be9d2bfd 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts @@ -10,13 +10,23 @@ import type { Registry } from "@/mod"; import { ACTOR_VERSIONED } from "@/schemas/actor-persist/versioned"; import type * as protocol from "@/schemas/client-protocol/mod"; import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned"; +import { ToClientSchema } from "@/schemas/client-protocol-zod/mod"; import { EXTRA_ERROR_LOG, idToStr } from "@/utils"; import type { ActorConfig, InitContext } from "../config"; import type { ConnDriver } from "../conn/driver"; import { createHttpSocket } from "../conn/drivers/http"; -import { CONN_PERSIST_SYMBOL, type Conn, type ConnId } from "../conn/mod"; +import { + CONN_DRIVER_SYMBOL, + CONN_PERSIST_SYMBOL, + CONN_SEND_MESSAGE_SYMBOL, + CONN_STATE_ENABLED_SYMBOL, + type Conn, + type ConnId, +} from "../conn/mod"; import { ActionContext } from "../contexts/action"; import { ActorContext } from "../contexts/actor"; +import { RequestContext } from "../contexts/request"; +import { WebSocketContext } from "../contexts/websocket"; import type { AnyDatabaseProvider, InferDatabaseClient } from "../database"; import type { ActorDriver } from "../driver"; import * as errors from "../errors"; @@ -34,8 +44,6 @@ import { type SaveStateOptions, StateManager } from "./state-manager"; export type { SaveStateOptions }; -export const ACTOR_INSTANCE_PERSIST_SYMBOL = Symbol("persist"); - enum CanSleep { Yes, NotReady, @@ -66,7 +74,7 @@ export class ActorInstance { // MARK: - Core Properties actorContext: ActorContext; #config: ActorConfig; - #actorDriver!: ActorDriver; + driver!: ActorDriver; #inlineClient!: Client>; #actorId!: string; #name!: string; @@ -74,9 +82,12 @@ export class ActorInstance { #region!: string; // MARK: - Managers - #connectionManager!: ConnectionManager; + connectionManager!: ConnectionManager; + #stateManager!: StateManager; - #eventManager!: EventManager; + + eventManager!: EventManager; + #scheduleManager!: ScheduleManager; // MARK: - Logging @@ -130,14 +141,17 @@ export class ActorInstance { }, getConnections: async () => { return Array.from( - this.#connectionManager.connections.entries(), + this.connectionManager.connections.entries(), ).map(([id, conn]) => ({ + type: conn[CONN_DRIVER_SYMBOL]?.type, id, params: conn.params as any, - state: conn.stateEnabled ? conn.state : undefined, + state: conn[CONN_STATE_ENABLED_SYMBOL] + ? conn.state + : undefined, subscriptions: conn.subscriptions.size, lastSeen: conn.lastSeen, - stateEnabled: conn.stateEnabled, + stateEnabled: conn[CONN_STATE_ENABLED_SYMBOL], isHibernatable: conn.isHibernatable, hibernatableRequestId: conn[CONN_PERSIST_SYMBOL] .hibernatableRequestId @@ -155,20 +169,21 @@ export class ActorInstance { await this.#stateManager.saveState({ immediate: true }); }, executeAction: async (name, params) => { - const conn = await this.createConn( + const conn = await this.connectionManager.prepareAndConnectConn( createHttpSocket(), - undefined, + // TODO: This may cause issues + undefined as unknown as CP, undefined, ); try { return await this.executeAction( - new ActionContext(this.actorContext, conn), + new ActionContext(this, conn), name, params || [], ); } finally { - this.connDisconnected(conn, true); + conn.disconnect(); } }, }; @@ -224,7 +239,7 @@ export class ActorInstance { } get conns(): Map> { - return this.#connectionManager.connections; + return this.connectionManager.connections; } get schedule(): Schedule { @@ -244,7 +259,7 @@ export class ActorInstance { } // MARK: - State Access - get [ACTOR_INSTANCE_PERSIST_SYMBOL](): PersistedActor { + get persist(): PersistedActor { return this.#stateManager.persist; } @@ -288,7 +303,7 @@ export class ActorInstance { region: string, ) { // Initialize properties - this.#actorDriver = actorDriver; + this.driver = actorDriver; this.#inlineClient = inlineClient; this.#actorId = actorId; this.#name = name; @@ -299,9 +314,9 @@ export class ActorInstance { this.#initializeLogging(); // Initialize managers - this.#connectionManager = new ConnectionManager(this); + this.connectionManager = new ConnectionManager(this); this.#stateManager = new StateManager(this, actorDriver, this.#config); - this.#eventManager = new EventManager(this); + this.eventManager = new EventManager(this); this.#scheduleManager = new ScheduleManager( this, actorDriver, @@ -336,7 +351,7 @@ export class ActorInstance { this.#rLog.info({ msg: "actor ready" }); // Start sleep timer - this.#resetSleepTimer(); + this.resetSleepTimer(); // Trigger any pending alarms await this.onAlarm(); @@ -347,7 +362,7 @@ export class ActorInstance { return this.#ready; } - #assertReady(allowStoppingState: boolean = false) { + assertReady(allowStoppingState: boolean = false) { if (!this.#ready) throw new errors.InternalError("Actor not ready"); if (!allowStoppingState && this.#stopCalled) throw new errors.InternalError("Actor is stopping"); @@ -410,10 +425,7 @@ export class ActorInstance { } this.#sleepCalled = true; - const sleep = this.#actorDriver.startSleep?.bind( - this.#actorDriver, - this.#actorId, - ); + const sleep = this.driver.startSleep?.bind(this.driver, this.#actorId); invariant(this.#sleepingSupported, "sleeping not supported"); invariant(sleep, "no sleep on driver"); @@ -427,7 +439,7 @@ export class ActorInstance { // MARK: - HTTP Request Tracking beginHonoHttpRequest() { this.#activeHonoHttpRequests++; - this.#resetSleepTimer(); + this.resetSleepTimer(); } endHonoHttpRequest() { @@ -439,86 +451,39 @@ export class ActorInstance { ...EXTRA_ERROR_LOG, }); } - this.#resetSleepTimer(); + this.resetSleepTimer(); } // MARK: - State Management async saveState(opts: SaveStateOptions) { - this.#assertReady(opts.allowStoppingState); + this.assertReady(opts.allowStoppingState); // Save state through StateManager await this.#stateManager.saveState(opts); // Save connection changes - if (this.#connectionManager.changedConnections.size > 0) { - const entries = this.#connectionManager.getChangedConnectionsData(); + if (this.connectionManager.changedConnections.size > 0) { + const entries = this.connectionManager.getChangedConnectionsData(); if (entries.length > 0) { - await this.#actorDriver.kvBatchPut(this.#actorId, entries); + await this.driver.kvBatchPut(this.#actorId, entries); } - this.#connectionManager.clearChangedConnections(); + this.connectionManager.clearChangedConnections(); } } - // MARK: - Connection Management - getConnForId(id: string): Conn | undefined { - return this.#connectionManager.getConnForId(id); - } - - markConnChanged(conn: Conn) { - this.#connectionManager.markConnChanged(conn); - } - - connDisconnected(conn: Conn, wasClean: boolean) { - this.#connectionManager.connDisconnected( - conn, - wasClean, - this.#actorDriver, - this.#eventManager, - ); - this.#resetSleepTimer(); - } - - async createConn( - driver: ConnDriver, - params: any, - request?: Request, - ): Promise> { - this.#assertReady(); - - const conn = await this.#connectionManager.createConn( - driver, - params, - request, - ); - - // Reset sleep timer after connection - this.#resetSleepTimer(); - - // Save state immediately - await this.saveState({ immediate: true }); - - // Send init message - conn.sendMessage( - new CachedSerializer( - { - body: { - tag: "Init", - val: { - actorId: this.id, - connectionId: conn.id, - }, - }, - }, - TO_CLIENT_VERSIONED, - ), - ); - - return conn; - } - // MARK: - Message Processing async processMessage( - message: protocol.ToServer, + message: { + body: + | { + tag: "ActionRequest"; + val: { id: bigint; name: string; args: unknown }; + } + | { + tag: "SubscriptionRequest"; + val: { eventName: string; subscribe: boolean }; + }; + }, conn: Conn, ) { await processMessage(message, this, conn, { @@ -537,7 +502,7 @@ export class ActorInstance { eventName, connId: conn.id, }); - this.#eventManager.addSubscription(eventName, conn, false); + this.eventManager.addSubscription(eventName, conn, false); }, onUnsubscribe: async (eventName, conn) => { this.inspector.emitter.emit("eventFired", { @@ -545,7 +510,7 @@ export class ActorInstance { eventName, connId: conn.id, }); - this.#eventManager.removeSubscription(eventName, conn, false); + this.eventManager.removeSubscription(eventName, conn, false); }, }); } @@ -635,29 +600,26 @@ export class ActorInstance { } // MARK: - HTTP/WebSocket Handlers - async handleFetch( + async handleRawRequest( + conn: Conn, request: Request, - opts: Record, ): Promise { - this.#assertReady(); + this.assertReady(); - if (!this.#config.onFetch) { - throw new errors.FetchHandlerNotDefined(); + if (!this.#config.onRequest) { + throw new errors.RequestHandlerNotDfeined(); } try { - const response = await this.#config.onFetch( - this.actorContext, - request, - opts, - ); + const ctx = new RequestContext(this, conn, request); + const response = await this.#config.onRequest(ctx, request); if (!response) { - throw new errors.InvalidFetchResponse(); + throw new errors.InvalidRequestHandlerResponse(); } return response; } catch (error) { this.#rLog.error({ - msg: "onFetch error", + msg: "onRequest error", error: stringifyError(error), }); throw error; @@ -666,28 +628,36 @@ export class ActorInstance { } } - async handleWebSocket( + handleRawWebSocket( + conn: Conn, websocket: UniversalWebSocket, - opts: { request: Request }, - ): Promise { - this.#assertReady(); + request?: Request, + ) { + // NOTE: All code before `onWebSocket` must be synchronous in order to ensure the order of `open` events happen in the correct order. + + this.assertReady(); if (!this.#config.onWebSocket) { throw new errors.InternalError("onWebSocket handler not defined"); } try { - const stateBeforeHandler = this.#stateManager.persistChanged; - // Reset sleep timer when handling WebSocket - this.#resetSleepTimer(); + this.resetSleepTimer(); // Handle WebSocket - await this.#config.onWebSocket(this.actorContext, websocket, opts); + const ctx = new WebSocketContext(this, conn, request); + + // NOTE: This is async and will run in the background + const voidOrPromise = this.#config.onWebSocket(ctx, websocket); - // Save state if changed - if (this.#stateManager.persistChanged && !stateBeforeHandler) { - await this.saveState({ immediate: true }); + // Save changes from the WebSocket open + if (voidOrPromise instanceof Promise) { + voidOrPromise.then(() => { + this.#stateManager.savePersistThrottled(); + }); + } else { + this.#stateManager.savePersistThrottled(); } } catch (error) { this.#rLog.error({ @@ -695,17 +665,9 @@ export class ActorInstance { error: stringifyError(error), }); throw error; - } finally { - this.#stateManager.savePersistThrottled(); } } - // MARK: - Event Broadcasting - broadcast>(name: string, ...args: Args) { - this.#assertReady(); - this.#eventManager.broadcast(name, ...args); - } - // MARK: - Scheduling async scheduleEvent( timestamp: number, @@ -716,13 +678,13 @@ export class ActorInstance { } async onAlarm() { - this.#resetSleepTimer(); + this.resetSleepTimer(); await this.#scheduleManager.onAlarm(); } // MARK: - Background Tasks waitUntil(promise: Promise) { - this.#assertReady(); + this.assertReady(); const nonfailablePromise = promise .then(() => { @@ -745,7 +707,7 @@ export class ActorInstance { actorId: this.#actorId, }; - const extraLogParams = this.#actorDriver.getExtraActorLogParams?.(); + const extraLogParams = this.driver.getExtraActorLogParams?.(); if (extraLogParams) Object.assign(logParams, extraLogParams); this.#log = getBaseLogger().child( @@ -764,7 +726,7 @@ export class ActorInstance { async #initializeState() { // Read initial state from KV - const [persistDataBuffer] = await this.#actorDriver.kvBatchGet( + const [persistDataBuffer] = await this.driver.kvBatchGet( this.#actorId, [KEYS.PERSIST_DATA], ); @@ -792,7 +754,7 @@ export class ActorInstance { async #restoreExistingActor(persistData: PersistedActor) { // List all connection keys - const connEntries = await this.#actorDriver.kvListPrefix( + const connEntries = await this.driver.kvListPrefix( this.#actorId, KEYS.CONN_PREFIX, ); @@ -821,10 +783,7 @@ export class ActorInstance { this.#stateManager.initPersistProxy(persistData); // Restore connections - this.#connectionManager.restoreConnections( - connections, - this.#eventManager, - ); + this.connectionManager.restoreConnections(connections); } async #createNewActor(persistData: PersistedActor) { @@ -841,10 +800,9 @@ export class ActorInstance { async #initializeInspectorToken() { // Try to load existing token - const [tokenBuffer] = await this.#actorDriver.kvBatchGet( - this.#actorId, - [KEYS.INSPECTOR_TOKEN], - ); + const [tokenBuffer] = await this.driver.kvBatchGet(this.#actorId, [ + KEYS.INSPECTOR_TOKEN, + ]); if (tokenBuffer !== null) { // Token exists, decode it @@ -855,7 +813,7 @@ export class ActorInstance { // Generate new token this.#inspectorToken = generateSecureToken(); const tokenBytes = new TextEncoder().encode(this.#inspectorToken); - await this.#actorDriver.kvBatchPut(this.#actorId, [ + await this.driver.kvBatchPut(this.#actorId, [ [KEYS.INSPECTOR_TOKEN, tokenBytes], ]); this.#rLog.debug({ msg: "generated new inspector token" }); @@ -867,7 +825,7 @@ export class ActorInstance { if ("createVars" in this.#config) { const dataOrPromise = this.#config.createVars( this.actorContext as unknown as InitContext, - this.#actorDriver.getContext(this.#actorId), + this.driver.getContext(this.#actorId), ); if (dataOrPromise instanceof Promise) { vars = await deadline( @@ -922,7 +880,7 @@ export class ActorInstance { async #setupDatabase() { if ("db" in this.#config && this.#config.db) { const client = await this.#config.db.createClient({ - getDatabase: () => this.#actorDriver.getDatabase(this.#actorId), + getDatabase: () => this.driver.getDatabase(this.#actorId), }); this.#rLog.info({ msg: "database migration starting" }); await this.#config.db.onMigrate?.(client); @@ -933,7 +891,7 @@ export class ActorInstance { async #disconnectConnections() { const promises: Promise[] = []; - for (const connection of this.#connectionManager.connections.values()) { + for (const connection of this.connectionManager.connections.values()) { if (!connection.isHibernatable) { this.#rLog.debug({ msg: "disconnecting non-hibernatable connection on actor stop", @@ -983,7 +941,7 @@ export class ActorInstance { } } - #resetSleepTimer() { + resetSleepTimer() { if (this.#config.options.noSleep || !this.#sleepingSupported) return; if (this.#stopCalled) return; @@ -1015,7 +973,7 @@ export class ActorInstance { if (this.#activeHonoHttpRequests > 0) return CanSleep.ActiveHonoHttpRequests; - for (const _conn of this.#connectionManager.connections.values()) { + for (const _conn of this.connectionManager.connections.values()) { return CanSleep.ActiveConns; } @@ -1023,7 +981,7 @@ export class ActorInstance { } get #sleepingSupported(): boolean { - return this.#actorDriver.startSleep !== undefined; + return this.driver.startSleep !== undefined; } get #varsEnabled(): boolean { diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/mod.ts b/rivetkit-typescript/packages/rivetkit/src/actor/mod.ts index 4a8c010b74..686a1a275e 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/mod.ts @@ -76,6 +76,12 @@ export type * from "./config"; export type { Conn } from "./conn/mod"; export type { ActionContext } from "./contexts/action"; export type { ActorContext } from "./contexts/actor"; +export type { ConnInitContext } from "./contexts/conn-init"; +export type { CreateConnStateContext } from "./contexts/create-conn-state"; +export type { OnBeforeConnectContext } from "./contexts/on-before-connect"; +export type { OnConnectContext } from "./contexts/on-connect"; +export type { RequestContext } from "./contexts/request"; +export type { WebSocketContext } from "./contexts/websocket"; export type { ActionContextOf, ActorContextOf, @@ -90,6 +96,6 @@ export { createActorRouter, } from "./router"; export { - handleRawWebSocketHandler, + handleRawWebSocket as handleRawWebSocketHandler, handleWebSocketConnect, } from "./router-endpoints"; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/protocol/old.ts b/rivetkit-typescript/packages/rivetkit/src/actor/protocol/old.ts index 0cd53e94cb..c830cd86f3 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/protocol/old.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/protocol/old.ts @@ -13,9 +13,15 @@ import { TO_CLIENT_VERSIONED, TO_SERVER_VERSIONED, } from "@/schemas/client-protocol/versioned"; +import { + type ToClient as ToClientJson, + ToClientSchema, + type ToServer as ToServerJson, + ToServerSchema, +} from "@/schemas/client-protocol-zod/mod"; import { deserializeWithEncoding } from "@/serde"; import { assertUnreachable, bufferToArrayBuffer } from "../../utils"; -import type { Conn } from "../conn/mod"; +import { CONN_SEND_MESSAGE_SYMBOL, type Conn } from "../conn/mod"; import { ActionContext } from "../contexts/action"; import type { ActorInstance } from "../instance/mod"; @@ -63,7 +69,17 @@ export async function inputDataToBuffer( export async function parseMessage( value: InputData, opts: MessageEventOpts, -): Promise { +): Promise<{ + body: + | { + tag: "ActionRequest"; + val: { id: bigint; name: string; args: unknown }; + } + | { + tag: "SubscriptionRequest"; + val: { eventName: string; subscribe: boolean }; + }; +}> { // Validate value length const length = getValueLength(value); if (length > opts.maxIncomingMessageSize) { @@ -81,7 +97,34 @@ export async function parseMessage( } // Deserialize message - return deserializeWithEncoding(opts.encoding, buffer, TO_SERVER_VERSIONED); + return deserializeWithEncoding( + opts.encoding, + buffer, + TO_SERVER_VERSIONED, + ToServerSchema, + // JSON: values are already the correct type + (json: ToServerJson): any => json, + // BARE: need to decode ArrayBuffer fields back to unknown + (bare: protocol.ToServer): any => { + if (bare.body.tag === "ActionRequest") { + return { + body: { + tag: "ActionRequest", + val: { + id: bare.body.val.id, + name: bare.body.val.name, + args: cbor.decode( + new Uint8Array(bare.body.val.args), + ), + }, + }, + }; + } else { + // SubscriptionRequest has no ArrayBuffer fields + return bare; + } + }, + ); } export interface ProcessMessageHandler< @@ -115,7 +158,17 @@ export async function processMessage< I, DB extends AnyDatabaseProvider, >( - message: protocol.ToServer, + message: { + body: + | { + tag: "ActionRequest"; + val: { id: bigint; name: string; args: unknown }; + } + | { + tag: "SubscriptionRequest"; + val: { eventName: string; subscribe: boolean }; + }; + }, actor: ActorInstance, conn: Conn, handler: ProcessMessageHandler, @@ -131,10 +184,9 @@ export async function processMessage< throw new errors.Unsupported("Action"); } - const { id, name, args: argsRaw } = message.body.val; + const { id, name, args } = message.body.val; actionId = id; actionName = name; - const args = cbor.decode(new Uint8Array(argsRaw)); actor.rLog.debug({ msg: "processing action request", @@ -142,14 +194,15 @@ export async function processMessage< actionName: name, }); - const ctx = new ActionContext( - actor.actorContext, - conn, - ); + const ctx = new ActionContext(actor, conn); // Process the action request and wait for the result // This will wait for async actions to complete - const output = await handler.onExecuteAction(ctx, name, args); + const output = await handler.onExecuteAction( + ctx, + name, + args as unknown[], + ); actor.rLog.debug({ msg: "sending action response", @@ -160,20 +213,31 @@ export async function processMessage< }); // Send the response back to the client - conn.sendMessage( - new CachedSerializer( - { + conn[CONN_SEND_MESSAGE_SYMBOL]( + new CachedSerializer( + output, + TO_CLIENT_VERSIONED, + ToClientSchema, + // JSON: output is the raw value + (value): ToClientJson => ({ body: { - tag: "ActionResponse", + tag: "ActionResponse" as const, val: { id: id, - output: bufferToArrayBuffer( - cbor.encode(output), - ), + output: value, }, }, - }, - TO_CLIENT_VERSIONED, + }), + // BARE/CBOR: output needs to be CBOR-encoded to ArrayBuffer + (value): protocol.ToClient => ({ + body: { + tag: "ActionResponse" as const, + val: { + id: id, + output: bufferToArrayBuffer(cbor.encode(value)), + }, + }, + }), ), ); @@ -229,23 +293,54 @@ export async function processMessage< }); // Build response - conn.sendMessage( - new CachedSerializer( - { + const errorData = { group, code, message, metadata, actionId }; + conn[CONN_SEND_MESSAGE_SYMBOL]( + new CachedSerializer( + errorData, + TO_CLIENT_VERSIONED, + ToClientSchema, + // JSON: metadata is the raw value (keep as undefined if not present) + (value): ToClientJson => { + const val: any = { + group: value.group, + code: value.code, + message: value.message, + actionId: + value.actionId !== undefined + ? value.actionId + : null, + }; + if (value.metadata !== undefined) { + val.metadata = value.metadata; + } + return { + body: { + tag: "Error" as const, + val, + }, + }; + }, + // BARE/CBOR: metadata needs to be CBOR-encoded to ArrayBuffer + // Note: protocol.Error expects `| null` for optional fields (BARE protocol) + (value): protocol.ToClient => ({ body: { - tag: "Error", + tag: "Error" as const, val: { - group, - code, - message, - metadata: bufferToArrayBuffer( - cbor.encode(metadata), - ), - actionId: actionId ?? null, + group: value.group, + code: value.code, + message: value.message, + metadata: value.metadata + ? bufferToArrayBuffer( + cbor.encode(value.metadata), + ) + : null, + actionId: + value.actionId !== undefined + ? value.actionId + : null, }, }, - }, - TO_CLIENT_VERSIONED, + }), ), ); diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/protocol/serde.ts b/rivetkit-typescript/packages/rivetkit/src/actor/protocol/serde.ts index 1c84128f88..4e4ed85432 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/protocol/serde.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/protocol/serde.ts @@ -22,14 +22,26 @@ export type Encoding = z.infer; /** * Helper class that helps serialize data without re-serializing for the same encoding. */ -export class CachedSerializer { +export class CachedSerializer { #data: T; #cache = new Map(); - #versionedDataHandler: VersionedDataHandler; - - constructor(data: T, versionedDataHandler: VersionedDataHandler) { + #versionedDataHandler: VersionedDataHandler; + #zodSchema: z.ZodType; + #toJson: (value: T) => TJson; + #toBare: (value: T) => TBare; + + constructor( + data: T, + versionedDataHandler: VersionedDataHandler, + zodSchema: z.ZodType, + toJson: (value: T) => TJson, + toBare: (value: T) => TBare, + ) { this.#data = data; this.#versionedDataHandler = versionedDataHandler; + this.#zodSchema = zodSchema; + this.#toJson = toJson; + this.#toBare = toBare; } public get rawData(): T { @@ -45,6 +57,9 @@ export class CachedSerializer { encoding, this.#data, this.#versionedDataHandler, + this.#zodSchema, + this.#toJson, + this.#toBare, ); this.#cache.set(encoding, serialized); return serialized; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts b/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts index 1a8885d6cb..e593a5150e 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts @@ -4,16 +4,15 @@ import type { WSContext } from "hono/ws"; import type { AnyConn } from "@/actor/conn/mod"; import { ActionContext } from "@/actor/contexts/action"; import * as errors from "@/actor/errors"; -import { - ACTOR_INSTANCE_PERSIST_SYMBOL, - type AnyActorInstance, -} from "@/actor/instance/mod"; +import type { AnyActorInstance } from "@/actor/instance/mod"; import type { InputData } from "@/actor/protocol/serde"; import { type Encoding, EncodingSchema } from "@/actor/protocol/serde"; import { HEADER_ACTOR_QUERY, HEADER_CONN_PARAMS, HEADER_ENCODING, + WS_PROTOCOL_CONN_PARAMS, + WS_PROTOCOL_ENCODING, } from "@/common/actor-router-consts"; import type { UpgradeWebSocketArgs } from "@/common/inline-websocket-adapter2"; import { deconstructError, stringifyError } from "@/common/utils"; @@ -25,6 +24,12 @@ import { HTTP_ACTION_REQUEST_VERSIONED, HTTP_ACTION_RESPONSE_VERSIONED, } from "@/schemas/client-protocol/versioned"; +import { + type HttpActionRequest as HttpActionRequestJson, + HttpActionRequestSchema, + type HttpActionResponse as HttpActionResponseJson, + HttpActionResponseSchema, +} from "@/schemas/client-protocol-zod/mod"; import { contentTypeForEncoding, deserializeWithEncoding, @@ -33,9 +38,11 @@ import { import { arrayBuffersEqual, bufferToArrayBuffer, + idToStr, promiseWithResolvers, } from "@/utils"; import { createHttpSocket } from "./conn/drivers/http"; +import { createRawRequestSocket } from "./conn/drivers/raw-request"; import { createRawWebSocketSocket } from "./conn/drivers/raw-websocket"; import { createWebSocketSocket } from "./conn/drivers/websocket"; import type { ActorDriver } from "./driver"; @@ -102,136 +109,59 @@ export async function handleWebSocketConnect( ? getRequestExposeInternalError(req) : false; - // Setup promise for the init handlers since all other behavior depends on this - const { - promise: handlersPromise, - resolve: handlersResolve, - reject: handlersReject, - } = promiseWithResolvers<{ - conn: AnyConn; - actor: AnyActorInstance; - connId: string; - }>(); - - // Pre-load the actor to catch errors early - let actor: AnyActorInstance; + let createdConn: AnyConn | undefined; try { - actor = await actorDriver.loadActor(actorId); - } catch (error) { - // Return handler that immediately closes with error + const actor = await actorDriver.loadActor(actorId); + + // Promise used to wait for the websocket close in `disconnect` + const closePromiseResolvers = promiseWithResolvers(); + + actor.rLog.debug({ + msg: "new websocket connection", + actorId, + }); + + // Check if this is a hibernatable websocket + const isHibernatable = + !!requestIdBuf && + actor.persist.hibernatableConns.findIndex((conn) => + arrayBuffersEqual(conn.hibernatableRequestId, requestIdBuf), + ) !== -1; + + const { driver, setWebSocket } = createWebSocketSocket( + requestId, + requestIdBuf, + isHibernatable, + encoding, + closePromiseResolvers.promise, + ); + const conn = await actor.connectionManager.prepareConn( + driver, + parameters, + req, + ); + createdConn = conn; + return { + // NOTE: onOpen cannot be async since this messes up the open event listener order onOpen: (_evt: any, ws: WSContext) => { - const { code } = deconstructError( - error, - actor.rLog, - { - wsEvent: "open", - }, - exposeInternalError, - ); - ws.close(1011, code); - }, - onMessage: (_evt: { data: any }, ws: WSContext) => { - ws.close(1011, "Actor not loaded"); - }, - onClose: (_event: any, _ws: WSContext) => {}, - onError: (_error: unknown) => {}, - }; - } + actor.rLog.debug("actor websocket open"); - // Promise used to wait for the websocket close in `disconnect` - const closePromiseResolvers = promiseWithResolvers(); + setWebSocket(ws); - // Track connection outside of scope for cleanup - let createdConn: AnyConn | undefined; - - return { - onOpen: (_evt: any, ws: WSContext) => { - actor.rLog.debug("actor websocket open"); - - // Run async operations in background - (async () => { - try { - let conn: AnyConn; - - actor.rLog.debug({ - msg: "new websocket connection", - actorId, - }); - - // Check if this is a hibernatable websocket - const isHibernatable = - !!requestIdBuf && - actor[ - ACTOR_INSTANCE_PERSIST_SYMBOL - ].hibernatableConns.findIndex((conn) => - arrayBuffersEqual( - conn.hibernatableRequestId, - requestIdBuf, - ), - ) !== -1; - - conn = await actor.createConn( - createWebSocketSocket( - requestId, - requestIdBuf, - isHibernatable, - encoding, - ws, - closePromiseResolvers.promise, - ), - parameters, - req, - ); - - // Store connection so we can clean on close - createdConn = conn; - - // Unblock other handlers - handlersResolve({ conn, actor, connId: conn.id }); - } catch (error) { - handlersReject(error); - - const { code } = deconstructError( - error, - actor.rLog, - { - wsEvent: "open", - }, - exposeInternalError, - ); - ws.close(1011, code); - } - })(); - }, - onMessage: (evt: { data: any }, ws: WSContext) => { - // Handle message asynchronously - handlersPromise - .then(({ conn, actor }) => { - actor.rLog.debug({ msg: "received message" }); - - const value = evt.data.valueOf() as InputData; - parseMessage(value, { - encoding: encoding, - maxIncomingMessageSize: - runConfig.maxIncomingMessageSize, - }) - .then((message) => { - actor - .processMessage(message, conn) - .catch((error) => { - const { code } = deconstructError( - error, - actor.rLog, - { - wsEvent: "message", - }, - exposeInternalError, - ); - ws.close(1011, code); - }); - }) - .catch((error) => { + actor.connectionManager.connectConn(conn); + }, + onMessage: (evt: { data: any }, ws: WSContext) => { + // Handle message asynchronously + actor.rLog.debug({ msg: "received message" }); + + const value = evt.data.valueOf() as InputData; + parseMessage(value, { + encoding: encoding, + maxIncomingMessageSize: runConfig.maxIncomingMessageSize, + }) + .then((message) => { + actor.processMessage(message, conn).catch((error) => { const { code } = deconstructError( error, actor.rLog, @@ -242,73 +172,93 @@ export async function handleWebSocketConnect( ); ws.close(1011, code); }); - }) - .catch((error) => { - const { code } = deconstructError( + }) + .catch((error) => { + const { code } = deconstructError( + error, + actor.rLog, + { + wsEvent: "message", + }, + exposeInternalError, + ); + ws.close(1011, code); + }); + }, + onClose: ( + event: { + wasClean: boolean; + code: number; + reason: string; + }, + ws: WSContext, + ) => { + closePromiseResolvers.resolve(); + + if (event.wasClean) { + actor.rLog.info({ + msg: "websocket closed", + code: event.code, + reason: event.reason, + wasClean: event.wasClean, + }); + } else { + actor.rLog.warn({ + msg: "websocket closed", + code: event.code, + reason: event.reason, + wasClean: event.wasClean, + }); + } + + // HACK: Close socket in order to fix bug with Cloudflare leaving WS in closing state + // https://github.com/cloudflare/workerd/issues/2569 + ws.close(1000, "hack_force_close"); + + // Wait for actor.createConn to finish before removing the connection + if (createdConn) { + createdConn.disconnect(event?.reason); + } + }, + onError: (_error: unknown) => { + try { + // Actors don't need to know about this, since it's abstracted away + actor.rLog.warn({ msg: "websocket error" }); + } catch (error) { + deconstructError( error, actor.rLog, - { - wsEvent: "message", - }, + { wsEvent: "error" }, exposeInternalError, ); - ws.close(1011, code); - }); - }, - onClose: ( - event: { - wasClean: boolean; - code: number; - reason: string; + } }, - ws: WSContext, - ) => { - handlersReject(`WebSocket closed (${event.code}): ${event.reason}`); - - closePromiseResolvers.resolve(); - - if (event.wasClean) { - actor.rLog.info({ - msg: "websocket closed", - code: event.code, - reason: event.reason, - wasClean: event.wasClean, - }); - } else { - actor.rLog.warn({ - msg: "websocket closed", - code: event.code, - reason: event.reason, - wasClean: event.wasClean, - }); - } + }; + } catch (error) { + const { group, code } = deconstructError( + error, + loggerWithoutContext(), + {}, + exposeInternalError, + ); - // HACK: Close socket in order to fix bug with Cloudflare leaving WS in closing state - // https://github.com/cloudflare/workerd/issues/2569 - ws.close(1000, "hack_force_close"); + // Clean up connection + if (createdConn) { + createdConn.disconnect(`${group}.${code}`); + } - // Wait for actor.createConn to finish before removing the connection - handlersPromise.finally(() => { - if (createdConn) { - const wasClean = event.wasClean || event.code === 1000; - actor.connDisconnected(createdConn, wasClean); - } - }); - }, - onError: (_error: unknown) => { - try { - // Actors don't need to know about this, since it's abstracted away - actor.rLog.warn({ msg: "websocket error" }); - } catch (error) { - deconstructError( - error, - actor.rLog, - { wsEvent: "error" }, - exposeInternalError, - ); - } - }, - }; + // Return handler that immediately closes with error + return { + onOpen: (_evt: any, ws: WSContext) => { + ws.close(1011, code); + }, + onMessage: (_evt: { data: any }, ws: WSContext) => { + ws.close(1011, "Actor not loaded"); + }, + onClose: (_event: any, _ws: WSContext) => {}, + onError: (_error: unknown) => {}, + }; + } } /** @@ -330,8 +280,14 @@ export async function handleAction( encoding, new Uint8Array(arrayBuffer), HTTP_ACTION_REQUEST_VERSIONED, + HttpActionRequestSchema, + // JSON: args is already the decoded value (raw object/array) + (json: HttpActionRequestJson) => json.args, + // BARE/CBOR: args is ArrayBuffer that needs CBOR-decoding + (bare: protocol.HttpActionRequest) => + cbor.decode(new Uint8Array(bare.args)), ); - const actionArgs = cbor.decode(new Uint8Array(request.args)); + const actionArgs = request; // Invoke the action let actor: AnyActorInstance | undefined; @@ -343,30 +299,33 @@ export async function handleAction( actor.rLog.debug({ msg: "handling action", actionName, encoding }); // Create conn - conn = await actor.createConn( + conn = await actor.connectionManager.prepareAndConnectConn( createHttpSocket(), parameters, c.req.raw, ); // Call action - const ctx = new ActionContext(actor.actorContext!, conn!); + const ctx = new ActionContext(actor, conn!); output = await actor.executeAction(ctx, actionName, actionArgs); } finally { if (conn) { - // HTTP connections don't have persistent sockets, so no socket ID needed - actor?.connDisconnected(conn, true); + conn.disconnect(); } } // Send response - const responseData: protocol.HttpActionResponse = { - output: bufferToArrayBuffer(cbor.encode(output)), - }; const serialized = serializeWithEncoding( encoding, - responseData, + output, HTTP_ACTION_RESPONSE_VERSIONED, + HttpActionResponseSchema, + // JSON: output is the raw value (will be serialized by jsonStringifyCompat) + (value): HttpActionResponseJson => ({ output: value }), + // BARE/CBOR: output needs to be CBOR-encoded to ArrayBuffer + (value): protocol.HttpActionResponse => ({ + output: bufferToArrayBuffer(cbor.encode(value)), + }), ); // TODO: Remvoe any, Hono is being a dumbass @@ -375,132 +334,179 @@ export async function handleAction( }); } -export async function handleRawWebSocketHandler( - req: Request | undefined, - path: string, +export async function handleRawRequest( + req: Request, actorDriver: ActorDriver, actorId: string, - requestIdBuf: ArrayBuffer | undefined, -): Promise { +): Promise { const actor = await actorDriver.loadActor(actorId); - // Promise used to wait for the websocket close in `disconnect` - const closePromiseResolvers = promiseWithResolvers(); - // Track connection outside of scope for cleanup let createdConn: AnyConn | undefined; - // Return WebSocket event handlers - return { - onOpen: async (evt: any, ws: any) => { - // Extract rivetRequestId provided by engine runner - const rivetRequestId = evt?.rivetRequestId; - const isHibernatable = - !!rivetRequestId && - actor[ - ACTOR_INSTANCE_PERSIST_SYMBOL - ].hibernatableConns.findIndex((conn) => - arrayBuffersEqual( - conn.hibernatableRequestId, - rivetRequestId, - ), - ) !== -1; - - // Wrap the Hono WebSocket in our adapter - const adapter = new HonoWebSocketAdapter( - ws, - rivetRequestId, - isHibernatable, - ); - - // Store adapter reference on the WebSocket for event handlers - (ws as any).__adapter = adapter; - - const newPath = truncateRawWebSocketPathPrefix(path); - let newRequest: Request; - if (req) { - newRequest = new Request(`http://actor${newPath}`, req); - } else { - newRequest = new Request(`http://actor${newPath}`, { - method: "GET", - }); - } + try { + const conn = await actor.connectionManager.prepareAndConnectConn( + createRawRequestSocket(), + {}, + req, + ); + + createdConn = conn; + + return await actor.handleRawRequest(conn, req); + } finally { + // Clean up the connection after the request completes + if (createdConn) { + createdConn.disconnect(); + } + } +} + +export async function handleRawWebSocket( + req: Request | undefined, + path: string, + actorDriver: ActorDriver, + actorId: string, + requestIdBuf: ArrayBuffer | undefined, +): Promise { + const exposeInternalError = req + ? getRequestExposeInternalError(req) + : false; - actor.rLog.debug({ - msg: "rewriting websocket url", - fromPath: path, - toUrl: newRequest.url, + let createdConn: AnyConn | undefined; + try { + const actor = await actorDriver.loadActor(actorId); + + // Promise used to wait for the websocket close in `disconnect` + const closePromiseResolvers = promiseWithResolvers(); + + // Extract rivetRequestId provided by engine runner + const isHibernatable = + !!requestIdBuf && + actor.persist.hibernatableConns.findIndex((conn) => + arrayBuffersEqual(conn.hibernatableRequestId, requestIdBuf), + ) !== -1; + + const newPath = truncateRawWebSocketPathPrefix(path); + let newRequest: Request; + if (req) { + newRequest = new Request(`http://actor${newPath}`, req); + } else { + newRequest = new Request(`http://actor${newPath}`, { + method: "GET", }); + } - try { - // Create connection using actor.createConn - this handles deduplication for hibernatable connections - const requestId = rivetRequestId - ? String(rivetRequestId) - : crypto.randomUUID(); - const conn = await actor.createConn( - createRawWebSocketSocket( - requestId, - rivetRequestId, - isHibernatable, - adapter, - closePromiseResolvers.promise, - ), - {}, // No parameters for raw WebSocket - newRequest, + actor.rLog.debug({ + msg: "rewriting websocket url", + fromPath: path, + toUrl: newRequest.url, + }); + // Create connection using actor.createConn - this handles deduplication for hibernatable connections + const requestIdStr = requestIdBuf + ? idToStr(requestIdBuf) + : crypto.randomUUID(); + const { driver, setWebSocket } = createRawWebSocketSocket( + requestIdStr, + requestIdBuf, + isHibernatable, + closePromiseResolvers.promise, + ); + const conn = await actor.connectionManager.prepareAndConnectConn( + driver, + {}, + newRequest, + ); + createdConn = conn; + + // Return WebSocket event handlers + return { + // NOTE: onOpen cannot be async since this will cause the client's open + // event to be called before this completes. Do all async work in + // handleRawWebSocket root. + onOpen: (_evt: any, ws: any) => { + // Wrap the Hono WebSocket in our adapter + const adapter = new HonoWebSocketAdapter( + ws, + requestIdBuf, + isHibernatable, ); - createdConn = conn; + // Store adapter reference on the WebSocket for event handlers + (ws as any).__adapter = adapter; + + setWebSocket(adapter); // Call the actor's onWebSocket handler with the adapted WebSocket - actor.handleWebSocket(adapter, { - request: newRequest, - }); - } catch (error) { - actor.rLog.error({ - msg: "failed to create raw WebSocket connection", - error: String(error), - }); - ws.close(1011, "Failed to create connection"); - } - }, - onMessage: (event: any, ws: any) => { - // Find the adapter for this WebSocket - const adapter = (ws as any).__adapter; - if (adapter) { - adapter._handleMessage(event); - } - }, - onClose: (evt: any, ws: any) => { - // Find the adapter for this WebSocket - const adapter = (ws as any).__adapter; - if (adapter) { - adapter._handleClose(evt?.code || 1006, evt?.reason || ""); - } + // + // NOTE: onWebSocket is called inside this function. Make sure + // this is called synchronously within onOpen. + actor.handleRawWebSocket(conn, adapter, newRequest); + }, + onMessage: (event: any, ws: any) => { + // Find the adapter for this WebSocket + const adapter = (ws as any).__adapter; + if (adapter) { + adapter._handleMessage(event); + } + }, + onClose: (evt: any, ws: any) => { + // Find the adapter for this WebSocket + const adapter = (ws as any).__adapter; + if (adapter) { + adapter._handleClose(evt?.code || 1006, evt?.reason || ""); + } - // Resolve the close promise - closePromiseResolvers.resolve(); + // Resolve the close promise + closePromiseResolvers.resolve(); - // Clean up the connection - if (createdConn) { - const wasClean = evt?.wasClean || evt?.code === 1000; - actor.connDisconnected(createdConn, wasClean); - } - }, - onError: (error: any, ws: any) => { - // Find the adapter for this WebSocket - const adapter = (ws as any).__adapter; - if (adapter) { - adapter._handleError(error); - } - }, - }; + // Clean up the connection + if (createdConn) { + createdConn.disconnect(evt?.reason); + } + }, + onError: (error: any, ws: any) => { + // Find the adapter for this WebSocket + const adapter = (ws as any).__adapter; + if (adapter) { + adapter._handleError(error); + } + }, + }; + } catch (error) { + const { group, code } = deconstructError( + error, + loggerWithoutContext(), + {}, + exposeInternalError, + ); + + // Clean up connection + if (createdConn) { + createdConn.disconnect(`${group}.${code}`); + } + + // Return handler that immediately closes with error + return { + onOpen: (_evt: any, ws: WSContext) => { + ws.close(1011, code); + }, + onMessage: (_evt: { data: any }, ws: WSContext) => { + ws.close(1011, "Actor not loaded"); + }, + onClose: (_event: any, _ws: WSContext) => {}, + onError: (_error: unknown) => {}, + }; + } } // Helper to get the connection encoding from a request +// +// Defaults to JSON if not provided so we can support vanilla curl requests easily. export function getRequestEncoding(req: HonoRequest): Encoding { const encodingParam = req.header(HEADER_ENCODING); if (!encodingParam) { - throw new errors.InvalidEncoding("undefined"); + return "json"; } const result = EncodingSchema.safeParse(encodingParam); @@ -550,6 +556,35 @@ export function getRequestConnParams(req: HonoRequest): unknown { } } +/** + * Parse encoding and connection parameters from WebSocket Sec-WebSocket-Protocol header + */ +export function parseWebSocketProtocols(protocols: string | null | undefined): { + encoding: Encoding; + connParams: unknown; +} { + let encodingRaw: string | undefined; + let connParamsRaw: string | undefined; + + if (protocols) { + const protocolList = protocols.split(",").map((p) => p.trim()); + for (const protocol of protocolList) { + if (protocol.startsWith(WS_PROTOCOL_ENCODING)) { + encodingRaw = protocol.substring(WS_PROTOCOL_ENCODING.length); + } else if (protocol.startsWith(WS_PROTOCOL_CONN_PARAMS)) { + connParamsRaw = decodeURIComponent( + protocol.substring(WS_PROTOCOL_CONN_PARAMS.length), + ); + } + } + } + + const encoding = EncodingSchema.parse(encodingRaw); + const connParams = connParamsRaw ? JSON.parse(connParamsRaw) : undefined; + + return { encoding, connParams }; +} + /** * Truncase the PATH_WEBSOCKET_PREFIX path prefix in order to pass a clean * path to the onWebSocket handler. diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/router.ts b/rivetkit-typescript/packages/rivetkit/src/actor/router.ts index 305ce03a38..e4c44d8ea3 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/router.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/router.ts @@ -1,6 +1,5 @@ import { Hono } from "hono"; import invariant from "invariant"; -import { EncodingSchema } from "@/actor/protocol/serde"; import { type ActionOpts, type ActionOutput, @@ -8,14 +7,14 @@ import { type ConnectWebSocketOutput, type ConnsMessageOpts, handleAction, - handleRawWebSocketHandler, + handleRawRequest, + handleRawWebSocket, handleWebSocketConnect, + parseWebSocketProtocols, } from "@/actor/router-endpoints"; import { PATH_CONNECT, PATH_WEBSOCKET_PREFIX, - WS_PROTOCOL_CONN_PARAMS, - WS_PROTOCOL_ENCODING, } from "@/common/actor-router-consts"; import { handleRouteError, @@ -31,7 +30,6 @@ import { isInspectorEnabled, secureInspector } from "@/inspector/utils"; import type { RunnerConfig } from "@/registry/run-config"; import { CONN_DRIVER_SYMBOL, generateConnRequestId } from "./conn/mod"; import type { ActorDriver } from "./driver"; -import { InternalError } from "./errors"; import { loggerWithoutContext } from "./log"; export type { @@ -93,7 +91,7 @@ export function createActorRouter( } const actor = await actorDriver.loadActor(c.env.actorId); - const conn = actor.getConnForId(connId); + const conn = actor.connectionManager.getConnForId(connId); if (!conn) { return c.text(`Connection not found: ${connId}`, 404); @@ -114,34 +112,8 @@ export function createActorRouter( return upgradeWebSocket(async (c) => { // Parse configuration from Sec-WebSocket-Protocol header const protocols = c.req.header("sec-websocket-protocol"); - let encodingRaw: string | undefined; - let connParamsRaw: string | undefined; - - if (protocols) { - const protocolList = protocols - .split(",") - .map((p) => p.trim()); - for (const protocol of protocolList) { - if (protocol.startsWith(WS_PROTOCOL_ENCODING)) { - encodingRaw = protocol.substring( - WS_PROTOCOL_ENCODING.length, - ); - } else if ( - protocol.startsWith(WS_PROTOCOL_CONN_PARAMS) - ) { - connParamsRaw = decodeURIComponent( - protocol.substring( - WS_PROTOCOL_CONN_PARAMS.length, - ), - ); - } - } - } - - const encoding = EncodingSchema.parse(encodingRaw); - const connParams = connParamsRaw - ? JSON.parse(connParamsRaw) - : undefined; + const { encoding, connParams } = + parseWebSocketProtocols(protocols); return await handleWebSocketConnect( c.req.raw, @@ -171,14 +143,11 @@ export function createActorRouter( ); }); - // Raw HTTP endpoints - /request/* router.all("/request/*", async (c) => { - const actor = await actorDriver.loadActor(c.env.actorId); - // TODO: This is not a clean way of doing this since `/http/` might exist mid-path // Strip the /http prefix from the URL to get the original path const url = new URL(c.req.url); - const originalPath = url.pathname.replace(/^\/raw\/http/, "") || "/"; + const originalPath = url.pathname.replace(/^\/request/, "") || "/"; // Create a new request with the corrected URL const correctedUrl = new URL(originalPath + url.search, url.origin); @@ -195,18 +164,13 @@ export function createActorRouter( to: correctedRequest.url, }); - // Call the actor's onFetch handler - it will throw appropriate errors - const response = await actor.handleFetch(correctedRequest, {}); - - // This should never happen now since handleFetch throws errors - if (!response) { - throw new InternalError("handleFetch returned void unexpectedly"); - } - - return response; + return await handleRawRequest( + correctedRequest, + actorDriver, + c.env.actorId, + ); }); - // Raw WebSocket endpoint - /websocket/* router.get(`${PATH_WEBSOCKET_PREFIX}*`, async (c) => { const upgradeWebSocket = runConfig.getUpgradeWebSocket?.(); if (upgradeWebSocket) { @@ -222,7 +186,7 @@ export function createActorRouter( pathWithQuery, }); - return await handleRawWebSocketHandler( + return await handleRawWebSocket( c.req.raw, pathWithQuery, actorDriver, diff --git a/rivetkit-typescript/packages/rivetkit/src/client/actor-conn.ts b/rivetkit-typescript/packages/rivetkit/src/client/actor-conn.ts index 73f8c16e5f..93041edd9b 100644 --- a/rivetkit-typescript/packages/rivetkit/src/client/actor-conn.ts +++ b/rivetkit-typescript/packages/rivetkit/src/client/actor-conn.ts @@ -25,6 +25,12 @@ import { TO_CLIENT_VERSIONED, TO_SERVER_VERSIONED, } from "@/schemas/client-protocol/versioned"; +import { + type ToClient as ToClientJson, + ToClientSchema, + type ToServer as ToServerJson, + ToServerSchema, +} from "@/schemas/client-protocol-zod/mod"; import { deserializeWithEncoding, encodingIsBinary, @@ -49,7 +55,7 @@ import { interface ActionInFlight { name: string; - resolve: (response: protocol.ActionResponse) => void; + resolve: (response: { id: bigint; output: unknown }) => void; reject: (error: Error) => void; } @@ -95,7 +101,17 @@ export class ActorConnRaw { #actorId?: string; #connectionId?: string; - #messageQueue: protocol.ToServer[] = []; + #messageQueue: Array<{ + body: + | { + tag: "ActionRequest"; + val: { id: bigint; name: string; args: unknown }; + } + | { + tag: "SubscriptionRequest"; + val: { eventName: string; subscribe: boolean }; + }; + }> = []; #actionsInFlight = new Map(); // biome-ignore lint/suspicious/noExplicitAny: Unknown subscription type @@ -172,8 +188,10 @@ export class ActorConnRaw { const actionId = this.#actionIdCounter; this.#actionIdCounter += 1; - const { promise, resolve, reject } = - promiseWithResolvers(); + const { promise, resolve, reject } = promiseWithResolvers<{ + id: bigint; + output: unknown; + }>(); this.#actionsInFlight.set(actionId, { name: opts.name, resolve, @@ -186,10 +204,10 @@ export class ActorConnRaw { val: { id: BigInt(actionId), name: opts.name, - args: bufferToArrayBuffer(cbor.encode(opts.args)), + args: opts.args, }, }, - } satisfies protocol.ToServer); + }); // TODO: Throw error if disconnect is called @@ -199,7 +217,7 @@ export class ActorConnRaw { `Request ID ${actionId} does not match response ID ${responseId}`, ); - return cbor.decode(new Uint8Array(output)) as Response; + return output as Response; } /** @@ -549,16 +567,15 @@ enc return inFlight; } - #dispatchEvent(event: protocol.Event) { - const { name, args: argsRaw } = event; - const args = cbor.decode(new Uint8Array(argsRaw)); + #dispatchEvent(event: { name: string; args: unknown }) { + const { name, args } = event; const listeners = this.#eventSubscriptions.get(name); if (!listeners) return; // Create a new array to avoid issues with listeners being removed during iteration for (const listener of [...listeners]) { - listener.callback(...args); + listener.callback(...(args as unknown[])); // Remove if this was a one-time listener if (listener.once) { @@ -664,7 +681,20 @@ enc }; } - #sendMessage(message: protocol.ToServer, opts?: SendHttpMessageOpts) { + #sendMessage( + message: { + body: + | { + tag: "ActionRequest"; + val: { id: bigint; name: string; args: unknown }; + } + | { + tag: "SubscriptionRequest"; + val: { eventName: string; subscribe: boolean }; + }; + }, + opts?: SendHttpMessageOpts, + ) { if (this.#disposed) { throw new errors.ActorConnDisposed(); } @@ -693,6 +723,28 @@ enc this.#encoding, message, TO_SERVER_VERSIONED, + ToServerSchema, + // JSON: args is the raw value + (msg): ToServerJson => msg as ToServerJson, + // BARE: args needs to be CBOR-encoded to ArrayBuffer + (msg): protocol.ToServer => { + if (msg.body.tag === "ActionRequest") { + return { + body: { + tag: "ActionRequest", + val: { + id: msg.body.val.id, + name: msg.body.val.name, + args: bufferToArrayBuffer( + cbor.encode(msg.body.val.args), + ), + }, + }, + }; + } else { + return msg as protocol.ToServer; + } + }, ); this.#websocket.send(messageSerialized); logger().trace({ @@ -734,7 +786,22 @@ enc } } - async #parseMessage(data: ConnMessage): Promise { + async #parseMessage(data: ConnMessage): Promise<{ + body: + | { tag: "Init"; val: { actorId: string; connectionId: string } } + | { + tag: "Error"; + val: { + group: string; + code: string; + message: string; + metadata: unknown; + actionId: bigint | null; + }; + } + | { tag: "ActionResponse"; val: { id: bigint; output: unknown } } + | { tag: "Event"; val: { name: string; args: unknown } }; + }> { invariant(this.#websocket, "websocket must be defined"); const buffer = await inputDataToBuffer(data); @@ -743,6 +810,59 @@ enc this.#encoding, buffer, TO_CLIENT_VERSIONED, + ToClientSchema, + // JSON: values are already the correct type + (msg): ToClientJson => msg as ToClientJson, + // BARE: need to decode ArrayBuffer fields back to unknown + (msg): any => { + if (msg.body.tag === "Error") { + return { + body: { + tag: "Error", + val: { + group: msg.body.val.group, + code: msg.body.val.code, + message: msg.body.val.message, + metadata: msg.body.val.metadata + ? cbor.decode( + new Uint8Array( + msg.body.val.metadata, + ), + ) + : null, + actionId: msg.body.val.actionId, + }, + }, + }; + } else if (msg.body.tag === "ActionResponse") { + return { + body: { + tag: "ActionResponse", + val: { + id: msg.body.val.id, + output: cbor.decode( + new Uint8Array(msg.body.val.output), + ), + }, + }, + }; + } else if (msg.body.tag === "Event") { + return { + body: { + tag: "Event", + val: { + name: msg.body.val.name, + args: cbor.decode( + new Uint8Array(msg.body.val.args), + ), + }, + }, + }; + } else { + // Init has no ArrayBuffer fields + return msg; + } + }, ); } diff --git a/rivetkit-typescript/packages/rivetkit/src/client/actor-handle.ts b/rivetkit-typescript/packages/rivetkit/src/client/actor-handle.ts index 22eb2e5b66..a0abcc8199 100644 --- a/rivetkit-typescript/packages/rivetkit/src/client/actor-handle.ts +++ b/rivetkit-typescript/packages/rivetkit/src/client/actor-handle.ts @@ -15,6 +15,12 @@ import { HTTP_ACTION_REQUEST_VERSIONED, HTTP_ACTION_RESPONSE_VERSIONED, } from "@/schemas/client-protocol/versioned"; +import { + type HttpActionRequest as HttpActionRequestJson, + HttpActionRequestSchema, + type HttpActionResponse as HttpActionResponseJson, + HttpActionResponseSchema, +} from "@/schemas/client-protocol-zod/mod"; import { bufferToArrayBuffer } from "@/utils"; import type { ActorDefinitionActions } from "./actor-common"; import { type ActorConn, ActorConnRaw } from "./actor-conn"; @@ -100,8 +106,12 @@ export class ActorHandleRaw { encoding: this.#encoding, }); const responseData = await sendHttpRequest< - protocol.HttpActionRequest, - protocol.HttpActionResponse + protocol.HttpActionRequest, // Bare type + protocol.HttpActionResponse, // Bare type + HttpActionRequestJson, // Json type + HttpActionResponseJson, // Json type + unknown[], // Request type (the args array) + Response // Response type (the output value) >({ url: `http://actor/action/${encodeURIComponent(opts.name)}`, method: "POST", @@ -111,9 +121,7 @@ export class ActorHandleRaw { ? { [HEADER_CONN_PARAMS]: JSON.stringify(this.#params) } : {}), }, - body: { - args: bufferToArrayBuffer(cbor.encode(opts.args)), - } satisfies protocol.HttpActionRequest, + body: opts.args, encoding: this.#encoding, customFetch: this.#driver.sendRequest.bind( this.#driver, @@ -122,9 +130,24 @@ export class ActorHandleRaw { signal: opts?.signal, requestVersionedDataHandler: HTTP_ACTION_REQUEST_VERSIONED, responseVersionedDataHandler: HTTP_ACTION_RESPONSE_VERSIONED, + requestZodSchema: HttpActionRequestSchema, + responseZodSchema: HttpActionResponseSchema, + // JSON Request: args is the raw value + requestToJson: (args): HttpActionRequestJson => ({ + args, + }), + // BARE Request: args needs to be CBOR-encoded + requestToBare: (args): protocol.HttpActionRequest => ({ + args: bufferToArrayBuffer(cbor.encode(args)), + }), + // JSON Response: output is the raw value + responseFromJson: (json): Response => json.output as Response, + // BARE Response: output is ArrayBuffer that needs CBOR-decoding + responseFromBare: (bare): Response => + cbor.decode(new Uint8Array(bare.output)) as Response, }); - return cbor.decode(new Uint8Array(responseData.output)); + return responseData; } catch (err) { // Standardize to ClientActorError instead of the native backend error const { group, code, message, metadata } = deconstructError( diff --git a/rivetkit-typescript/packages/rivetkit/src/client/utils.ts b/rivetkit-typescript/packages/rivetkit/src/client/utils.ts index 4de17e909a..3c4506a6b8 100644 --- a/rivetkit-typescript/packages/rivetkit/src/client/utils.ts +++ b/rivetkit-typescript/packages/rivetkit/src/client/utils.ts @@ -1,10 +1,15 @@ import * as cbor from "cbor-x"; import invariant from "invariant"; +import type { z } from "zod"; import type { Encoding } from "@/actor/protocol/serde"; import { assertUnreachable } from "@/common/utils"; import type { VersionedDataHandler } from "@/common/versioned-data"; import type { HttpResponseError } from "@/schemas/client-protocol/mod"; import { HTTP_RESPONSE_ERROR_VERSIONED } from "@/schemas/client-protocol/versioned"; +import { + type HttpResponseError as HttpResponseErrorJson, + HttpResponseErrorSchema, +} from "@/schemas/client-protocol-zod/mod"; import { contentTypeForEncoding, deserializeWithEncoding, @@ -33,25 +38,51 @@ export function messageLength(message: WebSocketMessage): number { assertUnreachable(message); } -export interface HttpRequestOpts { +export interface HttpRequestOpts< + RequestBare, + ResponseBare, + RequestJson = RequestBare, + ResponseJson = ResponseBare, + Request = RequestBare, + Response = ResponseBare, +> { method: string; url: string; headers: Record; - body?: RequestBody; + body?: Request; encoding: Encoding; skipParseResponse?: boolean; signal?: AbortSignal; - customFetch?: (req: Request) => Promise; - requestVersionedDataHandler: VersionedDataHandler | undefined; + customFetch?: (req: globalThis.Request) => Promise; + requestVersionedDataHandler: VersionedDataHandler | undefined; responseVersionedDataHandler: - | VersionedDataHandler + | VersionedDataHandler | undefined; + requestZodSchema: z.ZodType; + responseZodSchema: z.ZodType; + requestToJson: (value: Request) => RequestJson; + requestToBare: (value: Request) => RequestBare; + responseFromJson: (value: ResponseJson) => Response; + responseFromBare: (value: ResponseBare) => Response; } export async function sendHttpRequest< - RequestBody = unknown, - ResponseBody = unknown, ->(opts: HttpRequestOpts): Promise { + RequestBare = unknown, + ResponseBare = unknown, + RequestJson = RequestBare, + ResponseJson = ResponseBare, + Request = RequestBare, + Response = ResponseBare, +>( + opts: HttpRequestOpts< + RequestBare, + ResponseBare, + RequestJson, + ResponseJson, + Request, + Response + >, +): Promise { logger().debug({ msg: "sending http request", url: opts.url, @@ -64,19 +95,22 @@ export async function sendHttpRequest< if (opts.method === "POST" || opts.method === "PUT") { invariant(opts.body !== undefined, "missing body"); contentType = contentTypeForEncoding(opts.encoding); - bodyData = serializeWithEncoding( + bodyData = serializeWithEncoding( opts.encoding, opts.body, opts.requestVersionedDataHandler, + opts.requestZodSchema, + opts.requestToJson, + opts.requestToBare, ); } // Send request - let response: Response; + let response: globalThis.Response; try { // Make the HTTP request response = await (opts.customFetch ?? fetch)( - new Request(opts.url, { + new globalThis.Request(opts.url, { method: opts.method, headers: { ...opts.headers, @@ -102,12 +136,29 @@ export async function sendHttpRequest< if (!response.ok) { // Attempt to parse structured data const bufferResponse = await response.arrayBuffer(); - let responseData: HttpResponseError; + let responseData: { + group: string; + code: string; + message: string; + metadata: unknown; + }; try { responseData = deserializeWithEncoding( opts.encoding, new Uint8Array(bufferResponse), HTTP_RESPONSE_ERROR_VERSIONED, + HttpResponseErrorSchema, + // JSON: metadata is already unknown + (json): HttpResponseErrorJson => json as HttpResponseErrorJson, + // BARE: decode ArrayBuffer metadata to unknown + (bare): any => ({ + group: bare.group, + code: bare.code, + message: bare.message, + metadata: bare.metadata + ? cbor.decode(new Uint8Array(bare.metadata)) + : undefined, + }), ); } catch (error) { //logger().warn("failed to cleanly parse error, this is likely because a non-structured response is being served", { @@ -132,35 +183,30 @@ export async function sendHttpRequest< } } - // Decode metadata based on encoding - only binary encodings have CBOR-encoded metadata - let decodedMetadata: unknown; - if (responseData.metadata && encodingIsBinary(opts.encoding)) { - decodedMetadata = cbor.decode( - new Uint8Array(responseData.metadata), - ); - } - // Throw structured error throw new ActorError( responseData.group, responseData.code, responseData.message, - decodedMetadata, + responseData.metadata, ); } // Some requests don't need the success response to be parsed, so this can speed things up if (opts.skipParseResponse) { - return undefined as ResponseBody; + return undefined as Response; } // Parse the response based on encoding try { const buffer = new Uint8Array(await response.arrayBuffer()); - return deserializeWithEncoding( + return deserializeWithEncoding( opts.encoding, buffer, opts.responseVersionedDataHandler, + opts.responseZodSchema, + opts.responseFromJson, + opts.responseFromBare, ); } catch (error) { throw new HttpRequestError(`Failed to parse response: ${error}`, { diff --git a/rivetkit-typescript/packages/rivetkit/src/common/router.ts b/rivetkit-typescript/packages/rivetkit/src/common/router.ts index 64e2f8fd15..530f26ec73 100644 --- a/rivetkit-typescript/packages/rivetkit/src/common/router.ts +++ b/rivetkit-typescript/packages/rivetkit/src/common/router.ts @@ -7,9 +7,12 @@ import { } from "@/actor/router-endpoints"; import { buildActorNames, type RegistryConfig } from "@/registry/config"; import type { RunnerConfig } from "@/registry/run-config"; -import { getEndpoint } from "@/remote-manager-driver/api-utils"; -import { HttpResponseError } from "@/schemas/client-protocol/mod"; +import type * as protocol from "@/schemas/client-protocol/mod"; import { HTTP_RESPONSE_ERROR_VERSIONED } from "@/schemas/client-protocol/versioned"; +import { + type HttpResponseError as HttpResponseErrorJson, + HttpResponseErrorSchema, +} from "@/schemas/client-protocol-zod/mod"; import { encodingIsBinary, serializeWithEncoding } from "@/serde"; import { bufferToArrayBuffer, getEnvUniversal, VERSION } from "@/utils"; import { getLogger, type Logger } from "./log"; @@ -68,18 +71,28 @@ export function handleRouteError(error: unknown, c: HonoContext) { encoding = "json"; } + const errorData = { group, code, message, metadata }; const output = serializeWithEncoding( encoding, - { - group, - code, - message, - // TODO: Cannot serialize non-binary meta since it requires ArrayBuffer atm - metadata: encodingIsBinary(encoding) - ? bufferToArrayBuffer(cbor.encode(metadata)) - : null, - }, + errorData, HTTP_RESPONSE_ERROR_VERSIONED, + HttpResponseErrorSchema, + // JSON: metadata is the raw value (will be serialized by jsonStringifyCompat) + (value): HttpResponseErrorJson => ({ + group: value.group, + code: value.code, + message: value.message, + metadata: value.metadata, + }), + // BARE/CBOR: metadata needs to be CBOR-encoded to ArrayBuffer + (value): protocol.HttpResponseError => ({ + group: value.group, + code: value.code, + message: value.message, + metadata: value.metadata + ? bufferToArrayBuffer(cbor.encode(value.metadata)) + : null, + }), ); // TODO: Remove any @@ -125,12 +138,10 @@ export function handleMetadataRequest( : { normal: {} }, }, actorNames: buildActorNames(registryConfig), - // Do not return client endpoint if default server disabled - clientEndpoint: - runConfig.overrideServerAddress ?? - (runConfig.disableDefaultServer - ? undefined - : getEndpoint(runConfig)), + // If server address is changed, return a different client endpoint. + // Otherwise, return null indicating the client should use the current + // endpoint it's already configured with. + clientEndpoint: runConfig.overrideServerAddress, }; return c.json(response); diff --git a/rivetkit-typescript/packages/rivetkit/src/driver-helpers/mod.ts b/rivetkit-typescript/packages/rivetkit/src/driver-helpers/mod.ts index 7b36b9bc4f..d6ad658acc 100644 --- a/rivetkit-typescript/packages/rivetkit/src/driver-helpers/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/driver-helpers/mod.ts @@ -24,6 +24,7 @@ export type { GetForIdInput, GetOrCreateWithKeyInput, GetWithKeyInput, + ListActorsInput, ManagerDisplayInformation, ManagerDriver, } from "@/manager/driver"; diff --git a/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/test-inline-client-driver.ts b/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/test-inline-client-driver.ts index 4f672a414f..f25376b2f3 100644 --- a/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/test-inline-client-driver.ts +++ b/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/test-inline-client-driver.ts @@ -25,6 +25,7 @@ import { type GetOrCreateWithKeyInput, type GetWithKeyInput, HEADER_ACTOR_ID, + type ListActorsInput, type ManagerDisplayInformation, type ManagerDriver, } from "@/driver-helpers/mod"; @@ -73,6 +74,9 @@ export function createTestInlineClientDriver( input, ]); }, + listActors(input: ListActorsInput): Promise { + return makeInlineRequest(endpoint, encoding, "listActors", [input]); + }, async sendRequest( actorId: string, actorRequest: Request, @@ -129,7 +133,7 @@ export function createTestInlineClientDriver( if (errorData.error) { // Handle both error formats: // 1. { error: { code, message, metadata } } - structured format - // 2. { error: "message" } - simple string format (from custom onFetch handlers) + // 2. { error: "message" } - simple string format (from custom onRequest handlers) if (typeof errorData.error === "object") { throw new ClientActorError( errorData.error.code, @@ -138,7 +142,7 @@ export function createTestInlineClientDriver( ); } // For simple string errors, just return the response as-is - // This allows custom onFetch handlers to return their own error formats + // This allows custom onRequest handlers to return their own error formats } } catch (e) { // If it's not our error format, just return the response as-is @@ -452,7 +456,7 @@ export function createTestInlineClientDriver( // if (errorData.error) { // // Handle both error formats: // // 1. { error: { code, message, metadata } } - structured format - // // 2. { error: "message" } - simple string format (from custom onFetch handlers) + // // 2. { error: "message" } - simple string format (from custom onRequest handlers) // if (typeof errorData.error === "object") { // throw new ClientActorError( // errorData.error.code, @@ -461,7 +465,7 @@ export function createTestInlineClientDriver( // ); // } // // For simple string errors, just return the response as-is - // // This allows custom onFetch handlers to return their own error formats + // // This allows custom onRequest handlers to return their own error formats // } // } catch (e) { // // If it's not our error format, just return the response as-is diff --git a/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/raw-http-direct-registry.ts b/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/raw-http-direct-registry.ts index 359db4d54e..206b8f0e52 100644 --- a/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/raw-http-direct-registry.ts +++ b/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/raw-http-direct-registry.ts @@ -129,7 +129,7 @@ // expect(data).toEqual({ message: "Hello from actor!" }); // }); // -// test("should return 404 for actors without onFetch handler", async (c) => { +// test("should return 404 for actors without onRequest handler", async (c) => { // const { endpoint } = await setupDriverTest(c, driverTestConfig); // // const actorQuery: ActorQuery = { diff --git a/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/raw-http-request-properties.ts b/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/raw-http-request-properties.ts index 34f587fc06..201da0f1fd 100644 --- a/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/raw-http-request-properties.ts +++ b/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/raw-http-request-properties.ts @@ -8,7 +8,7 @@ export function runRawHttpRequestPropertiesTests( driverTestConfig: DriverTestConfig, ) { describe("raw http request properties", () => { - test("should pass all Request properties correctly to onFetch", async (c) => { + test("should pass all Request properties correctly to onRequest", async (c) => { const { client } = await setupDriverTest(c, driverTestConfig); const actor = client.rawHttpRequestPropertiesActor.getOrCreate([ "test", diff --git a/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/raw-http.ts b/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/raw-http.ts index ac20cfb39e..62ec39e661 100644 --- a/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/raw-http.ts +++ b/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/raw-http.ts @@ -79,7 +79,7 @@ export function runRawHttpTests(driverTestConfig: DriverTestConfig) { expect(response.status).toBe(404); }); - test("should return 404 when no onFetch handler defined", async (c) => { + test("should return 404 when no onRequest handler defined", async (c) => { const { client } = await setupDriverTest(c, driverTestConfig); const actor = client.rawHttpNoHandlerActor.getOrCreate([ "no-handler", @@ -89,10 +89,10 @@ export function runRawHttpTests(driverTestConfig: DriverTestConfig) { expect(response.ok).toBe(false); expect(response.status).toBe(404); - // No actions available without onFetch handler + // No actions available without onRequest handler }); - test("should return 500 error when onFetch returns void", async (c) => { + test("should return 500 error when onRequest returns void", async (c) => { const { client } = await setupDriverTest(c, driverTestConfig); const actor = client.rawHttpVoidReturnActor.getOrCreate([ "void-return", @@ -108,14 +108,14 @@ export function runRawHttpTests(driverTestConfig: DriverTestConfig) { message: string; }; expect(errorData.message).toContain( - "onFetch handler must return a Response", + "onRequest handler must return a Response", ); } catch { // If JSON parsing fails, just check that we got a 500 error // The error details are already validated by the status code } - // No actions available when onFetch returns void + // No actions available when onRequest returns void }); test("should handle different HTTP methods", async (c) => { diff --git a/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/request-access.ts b/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/request-access.ts index a5ac02f671..e11e6643d8 100644 --- a/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/request-access.ts +++ b/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/request-access.ts @@ -125,7 +125,7 @@ export function runRequestAccessTests(driverTestConfig: DriverTestConfig) { }); // TODO: re-expose this once we can have actor queries on the gateway - // test("should have access to request object in onFetch", async (c) => { + // test("should have access to request object in onRequest", async (c) => { // const { client, endpoint } = await setupDriverTest(c, driverTestConfig); // // // Create actor @@ -163,7 +163,7 @@ export function runRequestAccessTests(driverTestConfig: DriverTestConfig) { // expect(response.ok).toBe(true); // const data = await response.json(); // - // // Verify request info from onFetch + // // Verify request info from onRequest // expect((data as any).hasRequest).toBe(true); // expect((data as any).requestUrl).toContain("/test-path"); // expect((data as any).requestMethod).toBe("POST"); diff --git a/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts b/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts index 2ae17efeb9..c47337d692 100644 --- a/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts +++ b/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts @@ -11,21 +11,18 @@ import { WSContext } from "hono/ws"; import invariant from "invariant"; import { lookupInRegistry } from "@/actor/definition"; import { KEYS } from "@/actor/instance/kv"; -import { ACTOR_INSTANCE_PERSIST_SYMBOL } from "@/actor/instance/mod"; import { deserializeActorKey } from "@/actor/keys"; -import { EncodingSchema } from "@/actor/protocol/serde"; import { type ActorRouter, createActorRouter } from "@/actor/router"; import { - handleRawWebSocketHandler, + handleRawWebSocket, handleWebSocketConnect, + parseWebSocketProtocols, truncateRawWebSocketPathPrefix, } from "@/actor/router-endpoints"; import type { Client } from "@/client/client"; import { PATH_CONNECT, PATH_WEBSOCKET_PREFIX, - WS_PROTOCOL_CONN_PARAMS, - WS_PROTOCOL_ENCODING, } from "@/common/actor-router-consts"; import type { UpgradeWebSocketArgs } from "@/common/inline-websocket-adapter2"; import { getLogger } from "@/common/log"; @@ -170,8 +167,7 @@ export class EngineActorDriver implements ActorDriver { // Check for existing WS const hibernatableArray = - handler.actor[ACTOR_INSTANCE_PERSIST_SYMBOL] - .hibernatableConns; + handler.actor.persist.hibernatableConns; logger().debug({ msg: "checking hibernatable websockets", requestId: idToStr(requestId), @@ -542,29 +538,7 @@ export class EngineActorDriver implements ActorDriver { // Parse configuration from Sec-WebSocket-Protocol header (optional for path-based routing) const protocols = request.headers.get("sec-websocket-protocol"); - - let encodingRaw: string | undefined; - let connParamsRaw: string | undefined; - - if (protocols) { - const protocolList = protocols.split(",").map((p) => p.trim()); - for (const protocol of protocolList) { - if (protocol.startsWith(WS_PROTOCOL_ENCODING)) { - encodingRaw = protocol.substring( - WS_PROTOCOL_ENCODING.length, - ); - } else if (protocol.startsWith(WS_PROTOCOL_CONN_PARAMS)) { - connParamsRaw = decodeURIComponent( - protocol.substring(WS_PROTOCOL_CONN_PARAMS.length), - ); - } - } - } - - const encoding = EncodingSchema.parse(encodingRaw); - const connParams = connParamsRaw - ? JSON.parse(connParamsRaw) - : undefined; + const { encoding, connParams } = parseWebSocketProtocols(protocols); // Fetch WS handler // @@ -582,7 +556,7 @@ export class EngineActorDriver implements ActorDriver { requestIdBuf, ); } else if (url.pathname.startsWith(PATH_WEBSOCKET_PREFIX)) { - wsHandlerPromise = handleRawWebSocketHandler( + wsHandlerPromise = handleRawWebSocket( request, url.pathname + url.search, this, @@ -623,11 +597,14 @@ export class EngineActorDriver implements ActorDriver { // - Queue WS acks const actorHandler = this.#actors.get(actorId); if (actorHandler?.actor) { - const hibernatableWs = actorHandler.actor[ - ACTOR_INSTANCE_PERSIST_SYMBOL - ].hibernatableConns.find((conn: any) => - arrayBuffersEqual(conn.hibernatableRequestId, requestIdBuf), - ); + const hibernatableWs = + actorHandler.actor.persist.hibernatableConns.find( + (conn: any) => + arrayBuffersEqual( + conn.hibernatableRequestId, + requestIdBuf, + ), + ); if (hibernatableWs) { // Track msgIndex for sending acks @@ -758,8 +735,7 @@ export class EngineActorDriver implements ActorDriver { const actorHandler = this.#actors.get(actorId); if (actorHandler?.actor) { const hibernatableArray = - actorHandler.actor[ACTOR_INSTANCE_PERSIST_SYMBOL] - .hibernatableConns; + actorHandler.actor.persist.hibernatableConns; const wsIndex = hibernatableArray.findIndex((conn: any) => arrayBuffersEqual(conn.hibernatableRequestId, requestIdBuf), ); diff --git a/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/manager.ts b/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/manager.ts index adf5691d71..d080642900 100644 --- a/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/manager.ts +++ b/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/manager.ts @@ -3,7 +3,7 @@ import invariant from "invariant"; import { generateConnRequestId } from "@/actor/conn/mod"; import { type ActorRouter, createActorRouter } from "@/actor/router"; import { - handleRawWebSocketHandler, + handleRawWebSocket, handleWebSocketConnect, } from "@/actor/router-endpoints"; import { createClientWithDriver } from "@/client/client"; @@ -17,6 +17,7 @@ import type { GetForIdInput, GetOrCreateWithKeyInput, GetWithKeyInput, + ListActorsInput, ManagerDriver, } from "@/driver-helpers/mod"; import { ManagerInspector } from "@/inspector/manager"; @@ -182,7 +183,7 @@ export class FileSystemManagerDriver implements ManagerDriver { ) { // Handle websocket proxy // Use the full path with query parameters - const wsHandler = await handleRawWebSocketHandler( + const wsHandler = await handleRawWebSocket( undefined, path, this.#actorDriver, @@ -239,7 +240,7 @@ export class FileSystemManagerDriver implements ManagerDriver { ) { // Handle websocket proxy // Use the full path with query parameters - const wsHandler = await handleRawWebSocketHandler( + const wsHandler = await handleRawWebSocket( c.req.raw, path, this.#actorDriver, @@ -333,6 +334,31 @@ export class FileSystemManagerDriver implements ManagerDriver { }; } + async listActors({ name }: ListActorsInput): Promise { + const actors: ActorOutput[] = []; + const itr = this.#state.getActorsIterator({}); + + for await (const actor of itr) { + if (actor.name === name) { + actors.push({ + actorId: actor.actorId, + name: actor.name, + key: actor.key as string[], + createTs: Number(actor.createdAt), + }); + } + } + + // Sort by create ts desc (most recent first) + actors.sort((a, b) => { + const aTs = a.createTs ?? 0; + const bTs = b.createTs ?? 0; + return bTs - aTs; + }); + + return actors; + } + displayInformation(): ManagerDisplayInformation { return { name: this.#state.persist ? "File System" : "Memory", diff --git a/rivetkit-typescript/packages/rivetkit/src/manager-api/actors.ts b/rivetkit-typescript/packages/rivetkit/src/manager-api/actors.ts index 8bdbdce334..aed3dac11d 100644 --- a/rivetkit-typescript/packages/rivetkit/src/manager-api/actors.ts +++ b/rivetkit-typescript/packages/rivetkit/src/manager-api/actors.ts @@ -15,6 +15,11 @@ export const ActorSchema = z.object({ }); export type Actor = z.infer; +export const ActorNameSchema = z.object({ + metadata: z.record(z.string(), z.unknown()), +}); +export type ActorName = z.infer; + // MARK: GET /actors export const ActorsListResponseSchema = z.object({ actors: z.array(ActorSchema), @@ -61,3 +66,11 @@ export type ActorsGetOrCreateResponse = z.infer< // MARK: DELETE /actors/{} export const ActorsDeleteResponseSchema = z.object({}); export type ActorsDeleteResponse = z.infer; + +// MARK: GET /actors/names +export const ActorsListNamesResponseSchema = z.object({ + names: z.record(z.string(), ActorNameSchema), +}); +export type ActorsListNamesResponse = z.infer< + typeof ActorsListNamesResponseSchema +>; diff --git a/rivetkit-typescript/packages/rivetkit/src/manager/driver.ts b/rivetkit-typescript/packages/rivetkit/src/manager/driver.ts index fe70445e96..22dcabd1ed 100644 --- a/rivetkit-typescript/packages/rivetkit/src/manager/driver.ts +++ b/rivetkit-typescript/packages/rivetkit/src/manager/driver.ts @@ -14,6 +14,7 @@ export interface ManagerDriver { getWithKey(input: GetWithKeyInput): Promise; getOrCreateWithKey(input: GetOrCreateWithKeyInput): Promise; createActor(input: CreateInput): Promise; + listActors(input: ListActorsInput): Promise; sendRequest(actorId: string, actorRequest: Request): Promise; openWebSocket( @@ -92,8 +93,16 @@ export interface CreateInput { region?: string; } +export interface ListActorsInput { + c?: HonoContext | undefined; + name: string; + key?: string; + includeDestroyed?: boolean; +} + export interface ActorOutput { actorId: string; name: string; key: ActorKey; + createTs?: number; } diff --git a/rivetkit-typescript/packages/rivetkit/src/manager/hono-websocket-adapter.ts b/rivetkit-typescript/packages/rivetkit/src/manager/hono-websocket-adapter.ts index 6ceab6651d..b9a98fa984 100644 --- a/rivetkit-typescript/packages/rivetkit/src/manager/hono-websocket-adapter.ts +++ b/rivetkit-typescript/packages/rivetkit/src/manager/hono-websocket-adapter.ts @@ -38,7 +38,7 @@ export class HonoWebSocketAdapter implements UniversalWebSocket { // The WSContext is already open when we receive it this.#readyState = this.OPEN; - // Immediately fire the open event + // Fire open event on next tick so the runtime has time to schedule event listeners setTimeout(() => { this.#fireEvent("open", { type: "open", diff --git a/rivetkit-typescript/packages/rivetkit/src/manager/router.ts b/rivetkit-typescript/packages/rivetkit/src/manager/router.ts index d6c4c455b4..b54523335e 100644 --- a/rivetkit-typescript/packages/rivetkit/src/manager/router.ts +++ b/rivetkit-typescript/packages/rivetkit/src/manager/router.ts @@ -46,12 +46,14 @@ import { ActorsGetOrCreateRequestSchema, type ActorsGetOrCreateResponse, ActorsGetOrCreateResponseSchema, + type ActorsListNamesResponse, + ActorsListNamesResponseSchema, type ActorsListResponse, ActorsListResponseSchema, type Actor as ApiActor, } from "@/manager-api/actors"; import type { AnyClient } from "@/mod"; -import type { RegistryConfig } from "@/registry/config"; +import { buildActorNames, type RegistryConfig } from "@/registry/config"; import type { DriverConfig, RunnerConfig } from "@/registry/run-config"; import type { ActorOutput, ManagerDriver } from "./driver"; import { actorGateway, createTestWebSocketProxy } from "./gateway"; @@ -287,17 +289,7 @@ function addManagerRoutes( if (key && !name) { return c.json( { - error: "When providing 'key', 'name' must also be provided.", - }, - 400, - ); - } - - // Validate: must provide either actor_ids or (name + key) - if (!actorIdsParsed && !key) { - return c.json( - { - error: "Must provide either 'actor_ids' or both 'name' and 'key'.", + error: "Name is required when key is provided.", }, 400, ); @@ -349,16 +341,33 @@ function addManagerRoutes( } } } - } else if (key) { - // At this point, name is guaranteed to be defined due to validation above + } else if (key && name) { const actorOutput = await managerDriver.getWithKey({ c, - name: name!, + name, key: [key], // Convert string to ActorKey array }); if (actorOutput) { actors.push(actorOutput); } + } else { + if (!name) { + return c.json( + { + error: "Name is required when not using actor_ids.", + }, + 400, + ); + } + + // List all actors with the given name + const actorOutputs = await managerDriver.listActors({ + c, + name, + key, + includeDestroyed: false, + }); + actors.push(...actorOutputs); } return c.json({ @@ -369,6 +378,27 @@ function addManagerRoutes( }); } + // GET /actors/names + { + const route = createRoute({ + method: "get", + path: "/actors/names", + request: { + query: z.object({ + namespace: z.string(), + }), + }, + responses: buildOpenApiResponses(ActorsListNamesResponseSchema), + }); + + router.openapi(route, async (c) => { + const names = buildActorNames(registryConfig); + return c.json({ + names, + }); + }); + } + // PUT /actors { const route = createRoute({ @@ -699,7 +729,7 @@ function createApiActor( key: serializeActorKey(actor.key), namespace_id: "default", // Assert default namespace runner_name_selector: runnerName, - create_ts: Date.now(), + create_ts: actor.createTs ?? Date.now(), connectable_ts: null, destroy_ts: null, sleep_ts: null, diff --git a/rivetkit-typescript/packages/rivetkit/src/remote-manager-driver/api-endpoints.ts b/rivetkit-typescript/packages/rivetkit/src/remote-manager-driver/api-endpoints.ts index e23334c3c0..186b802f64 100644 --- a/rivetkit-typescript/packages/rivetkit/src/remote-manager-driver/api-endpoints.ts +++ b/rivetkit-typescript/packages/rivetkit/src/remote-manager-driver/api-endpoints.ts @@ -25,7 +25,7 @@ export async function getActor( ); } -// MARK: Get actor by id +// MARK: Get actor by key export async function getActorByKey( config: ClientConfig, name: string, @@ -39,6 +39,18 @@ export async function getActorByKey( ); } +// MARK: List actors by name +export async function listActorsByName( + config: ClientConfig, + name: string, +): Promise { + return apiCall( + config, + "GET", + `/actors?name=${encodeURIComponent(name)}`, + ); +} + // MARK: Get or create actor by id export async function getOrCreateActor( config: ClientConfig, diff --git a/rivetkit-typescript/packages/rivetkit/src/remote-manager-driver/api-utils.ts b/rivetkit-typescript/packages/rivetkit/src/remote-manager-driver/api-utils.ts index 1966270501..7499e2f054 100644 --- a/rivetkit-typescript/packages/rivetkit/src/remote-manager-driver/api-utils.ts +++ b/rivetkit-typescript/packages/rivetkit/src/remote-manager-driver/api-utils.ts @@ -1,3 +1,4 @@ +import { z } from "zod"; import type { ClientConfig } from "@/client/config"; import { sendHttpRequest } from "@/client/utils"; import { combineUrlPath } from "@/utils"; @@ -51,5 +52,12 @@ export async function apiCall( skipParseResponse: false, requestVersionedDataHandler: undefined, responseVersionedDataHandler: undefined, + requestZodSchema: z.any() as z.ZodType, + responseZodSchema: z.any() as z.ZodType, + // Identity conversions (passthrough for generic API calls) + requestToJson: (value) => value, + requestToBare: (value) => value, + responseFromJson: (value) => value, + responseFromBare: (value) => value, }); } diff --git a/rivetkit-typescript/packages/rivetkit/src/remote-manager-driver/mod.ts b/rivetkit-typescript/packages/rivetkit/src/remote-manager-driver/mod.ts index 62e734dada..e9eecef6e3 100644 --- a/rivetkit-typescript/packages/rivetkit/src/remote-manager-driver/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/remote-manager-driver/mod.ts @@ -12,6 +12,7 @@ import type { GetForIdInput, GetOrCreateWithKeyInput, GetWithKeyInput, + ListActorsInput, ManagerDisplayInformation, ManagerDriver, } from "@/driver-helpers/mod"; @@ -30,6 +31,7 @@ import { getActorByKey, getMetadata, getOrCreateActor, + listActorsByName, } from "./api-endpoints"; import { EngineApiError, getEndpoint } from "./api-utils"; import { logger } from "./log"; @@ -255,6 +257,24 @@ export class RemoteManagerDriver implements ManagerDriver { }; } + async listActors({ c, name }: ListActorsInput): Promise { + // Wait for metadata check to complete if in progress + if (this.#metadataPromise) { + await this.#metadataPromise; + } + + logger().debug({ msg: "listing actors via engine api", name }); + + const response = await listActorsByName(this.#config, name); + + return response.actors.map((actor) => ({ + actorId: actor.actor_id, + name: actor.name, + key: deserializeActorKey(actor.key), + createTs: actor.create_ts, + })); + } + async destroyActor(actorId: string): Promise { // Wait for metadata check to complete if in progress if (this.#metadataPromise) { diff --git a/rivetkit-typescript/packages/rivetkit/src/schemas/client-protocol-zod/mod.ts b/rivetkit-typescript/packages/rivetkit/src/schemas/client-protocol-zod/mod.ts new file mode 100644 index 0000000000..8460f713ac --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/src/schemas/client-protocol-zod/mod.ts @@ -0,0 +1,103 @@ +import { z } from "zod"; + +// Helper schemas +const UintSchema = z.bigint(); +const OptionalUintSchema = UintSchema.nullable(); + +// MARK: Message To Client +export const InitSchema = z.object({ + actorId: z.string(), + connectionId: z.string(), +}); +export type Init = z.infer; + +export const ErrorSchema = z.object({ + group: z.string(), + code: z.string(), + message: z.string(), + metadata: z.unknown().optional(), + actionId: OptionalUintSchema, +}); +export type Error = z.infer; + +export const ActionResponseSchema = z.object({ + id: UintSchema, + output: z.unknown(), +}); +export type ActionResponse = z.infer; + +export const EventSchema = z.object({ + name: z.string(), + args: z.unknown(), +}); +export type Event = z.infer; + +export const ToClientBodySchema = z.discriminatedUnion("tag", [ + z.object({ tag: z.literal("Init"), val: InitSchema }), + z.object({ tag: z.literal("Error"), val: ErrorSchema }), + z.object({ tag: z.literal("ActionResponse"), val: ActionResponseSchema }), + z.object({ tag: z.literal("Event"), val: EventSchema }), +]); +export type ToClientBody = z.infer; + +export const ToClientSchema = z.object({ + body: ToClientBodySchema, +}); +export type ToClient = z.infer; + +// MARK: Message To Server +export const ActionRequestSchema = z.object({ + id: UintSchema, + name: z.string(), + args: z.unknown(), +}); +export type ActionRequest = z.infer; + +export const SubscriptionRequestSchema = z.object({ + eventName: z.string(), + subscribe: z.boolean(), +}); +export type SubscriptionRequest = z.infer; + +export const ToServerBodySchema = z.discriminatedUnion("tag", [ + z.object({ tag: z.literal("ActionRequest"), val: ActionRequestSchema }), + z.object({ + tag: z.literal("SubscriptionRequest"), + val: SubscriptionRequestSchema, + }), +]); +export type ToServerBody = z.infer; + +export const ToServerSchema = z.object({ + body: ToServerBodySchema, +}); +export type ToServer = z.infer; + +// MARK: HTTP Action +export const HttpActionRequestSchema = z.object({ + args: z.unknown(), +}); +export type HttpActionRequest = z.infer; + +export const HttpActionResponseSchema = z.object({ + output: z.unknown(), +}); +export type HttpActionResponse = z.infer; + +// MARK: HTTP Error +export const HttpResponseErrorSchema = z.object({ + group: z.string(), + code: z.string(), + message: z.string(), + metadata: z.unknown().optional(), +}); +export type HttpResponseError = z.infer; + +// MARK: HTTP Resolve +export const HttpResolveRequestSchema = z.null(); +export type HttpResolveRequest = z.infer; + +export const HttpResolveResponseSchema = z.object({ + actorId: z.string(), +}); +export type HttpResolveResponse = z.infer; diff --git a/rivetkit-typescript/packages/rivetkit/src/serde.ts b/rivetkit-typescript/packages/rivetkit/src/serde.ts index b51ad5dbb7..d006f2e115 100644 --- a/rivetkit-typescript/packages/rivetkit/src/serde.ts +++ b/rivetkit-typescript/packages/rivetkit/src/serde.ts @@ -1,5 +1,6 @@ import * as cbor from "cbor-x"; import invariant from "invariant"; +import type { z } from "zod"; import { assertUnreachable } from "@/common/utils"; import type { VersionedDataHandler } from "@/common/versioned-data"; import type { Encoding } from "@/mod"; @@ -52,46 +53,65 @@ export function wsBinaryTypeForEncoding( } } -export function serializeWithEncoding( +export function serializeWithEncoding( encoding: Encoding, value: T, - versionedDataHandler: VersionedDataHandler | undefined, + versionedDataHandler: VersionedDataHandler | undefined, + zodSchema: z.ZodType, + toJson: (value: T) => TJson, + toBare: (value: T) => TBare, ): Uint8Array | string { if (encoding === "json") { - return jsonStringifyCompat(value); + const jsonValue = toJson(value); + const validated = zodSchema.parse(jsonValue); + return jsonStringifyCompat(validated); } else if (encoding === "cbor") { - return cbor.encode(value); + const jsonValue = toJson(value); + const validated = zodSchema.parse(jsonValue); + return cbor.encode(validated); } else if (encoding === "bare") { if (!versionedDataHandler) { throw new Error( "VersionedDataHandler is required for 'bare' encoding", ); } - return versionedDataHandler.serializeWithEmbeddedVersion(value); + const bareValue = toBare(value); + return versionedDataHandler.serializeWithEmbeddedVersion(bareValue); } else { assertUnreachable(encoding); } } -export function deserializeWithEncoding( +export function deserializeWithEncoding( encoding: Encoding, buffer: Uint8Array | string, - versionedDataHandler: VersionedDataHandler | undefined, + versionedDataHandler: VersionedDataHandler | undefined, + zodSchema: z.ZodType, + fromJson: (value: TJson) => T, + fromBare: (value: TBare) => T, ): T { if (encoding === "json") { + let parsed: unknown; if (typeof buffer === "string") { - return jsonParseCompat(buffer); + parsed = jsonParseCompat(buffer); } else { const decoder = new TextDecoder("utf-8"); const jsonString = decoder.decode(buffer); - return jsonParseCompat(jsonString); + parsed = jsonParseCompat(jsonString); } + const validated = zodSchema.parse(parsed); + return fromJson(validated); } else if (encoding === "cbor") { invariant( typeof buffer !== "string", "buffer cannot be string for cbor encoding", ); - return cbor.decode(buffer); + // Decode CBOR to get JavaScript values (similar to JSON.parse) + const decoded: unknown = cbor.decode(buffer); + // Validate with Zod schema (CBOR produces same structure as JSON) + const validated = zodSchema.parse(decoded); + // CBOR decoding produces JS objects, use fromJson + return fromJson(validated); } else if (encoding === "bare") { invariant( typeof buffer !== "string", @@ -102,7 +122,9 @@ export function deserializeWithEncoding( "VersionedDataHandler is required for 'bare' encoding", ); } - return versionedDataHandler.deserializeWithEmbeddedVersion(buffer); + const bareValue = + versionedDataHandler.deserializeWithEmbeddedVersion(buffer); + return fromBare(bareValue); } else { assertUnreachable(encoding); } diff --git a/website/public/llms-full.txt b/website/public/llms-full.txt index 26bdf66b13..c1c784cffc 100644 --- a/website/public/llms-full.txt +++ b/website/public/llms-full.txt @@ -1713,7 +1713,7 @@ For HTTP requests, the router expects these headers: ```typescript // Direct HTTP request to actor -const response = await fetch("http://localhost:8080/registry/actors/myActor/raw/http/api/hello", +const response = await fetch("http://localhost:8080/registry/actors/myActor/request/api/hello", }), "X-RivetKit-Encoding": "json", "X-RivetKit-Conn-Params": JSON.stringify() @@ -1724,7 +1724,7 @@ const data = await response.json(); console.log(data); // // POST request with data -const postResponse = await fetch("http://localhost:8080/registry/actors/myActor/raw/http/api/echo", +const postResponse = await fetch("http://localhost:8080/registry/actors/myActor/request/api/echo", }), "X-RivetKit-Encoding": "json", "X-RivetKit-Conn-Params": JSON.stringify(), @@ -6341,6 +6341,11 @@ kubectl -n rivet-engine wait --for=condition=ready pod -l app=postgres --timeout ### 3. Deploy Rivet Engine +The Rivet Engine deployment consists of two components: + +- **Main Engine Deployment**: Runs all services except singleton services. Configured with Horizontal Pod Autoscaling (HPA) to automatically scale between 2-10 replicas based on CPU and memory utilization. +- **Singleton Engine Deployment**: Runs singleton services that must have exactly 1 replica (e.g., schedulers, coordinators). + Save as `rivet-engine.yaml`: ```yaml @@ -6382,7 +6387,7 @@ metadata: name: rivet-engine namespace: rivet-engine spec: - replicas: 1 + replicas: 2 selector: matchLabels: app: rivet-engine @@ -6396,6 +6401,8 @@ spec: image: rivetkit/engine:latest args: - start + - --except-services + - singleton env: - name: RIVET_CONFIG_PATH value: /etc/rivet/config.jsonc @@ -6410,16 +6417,107 @@ spec: readOnly: true resources: requests: - cpu: 500m - memory: 1Gi + cpu: 2000m + memory: 4Gi limits: + cpu: 4000m + memory: 8Gi + startupProbe: + httpGet: + path: /health + port: 6421 + initialDelaySeconds: 30 + periodSeconds: 10 + failureThreshold: 30 + readinessProbe: + httpGet: + path: /health + port: 6421 + periodSeconds: 5 + failureThreshold: 2 + livenessProbe: + httpGet: + path: /health + port: 6421 + periodSeconds: 10 + failureThreshold: 3 + volumes: + - name: config + configMap: + name: engine-config +--- +apiVersion: autoscaling/v2 +kind: HorizontalPodAutoscaler +metadata: + name: rivet-engine + namespace: rivet-engine +spec: + scaleTargetRef: + apiVersion: apps/v1 + kind: Deployment + name: rivet-engine + minReplicas: 2 + maxReplicas: 10 + metrics: + - type: Resource + resource: + name: cpu + target: + type: Utilization + averageUtilization: 60 + - type: Resource + resource: + name: memory + target: + type: Utilization + averageUtilization: 80 +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: rivet-engine-singleton + namespace: rivet-engine +spec: + replicas: 1 + selector: + matchLabels: + app: rivet-engine-singleton + template: + metadata: + labels: + app: rivet-engine-singleton + spec: + containers: + - name: rivet-engine + image: rivetkit/engine:latest + args: + - start + - --services + - singleton + - --services + - api-peer + env: + - name: RIVET_CONFIG_PATH + value: /etc/rivet/config.jsonc + ports: + - containerPort: 6421 + name: api-peer + volumeMounts: + - name: config + mountPath: /etc/rivet + readOnly: true + resources: + requests: cpu: 2000m memory: 4Gi + limits: + cpu: 4000m + memory: 8Gi startupProbe: httpGet: path: /health port: 6421 - initialDelaySeconds: 10 + initialDelaySeconds: 30 periodSeconds: 10 failureThreshold: 30 readinessProbe: @@ -6427,11 +6525,13 @@ spec: path: /health port: 6421 periodSeconds: 5 + failureThreshold: 2 livenessProbe: httpGet: path: /health port: 6421 periodSeconds: 10 + failureThreshold: 3 volumes: - name: config configMap: @@ -6443,9 +6543,21 @@ Apply and wait for the engine to be ready: ```bash kubectl apply -f rivet-engine.yaml kubectl -n rivet-engine wait --for=condition=ready pod -l app=rivet-engine --timeout=300s +kubectl -n rivet-engine wait --for=condition=ready pod -l app=rivet-engine-singleton --timeout=300s ``` -### 4. Access the Engine +**Note**: The HPA requires a metrics server to be running in your cluster. Most Kubernetes distributions (including k3d, GKE, EKS, AKS) include this by default. + +### 4. Verify Deployment + +Check that all pods are running (you should see 2+ engine pods and 1 singleton pod): + +```bash +kubectl -n rivet-engine get pods +kubectl -n rivet-engine get hpa +``` + +### 5. Access the Engine Get the service URL: @@ -6528,14 +6640,44 @@ k3d cluster delete rivet ### Scaling -For horizontal scaling, update the deployment: +The engine is configured with Horizontal Pod Autoscaling (HPA) by default, automatically scaling between 2-10 replicas based on CPU (60%) and memory (80%) utilization. + +To adjust the scaling parameters, modify the HPA configuration: ```yaml +apiVersion: autoscaling/v2 +kind: HorizontalPodAutoscaler +metadata: + name: rivet-engine + namespace: rivet-engine spec: - replicas: 3 # Multiple replicas -``` + scaleTargetRef: + apiVersion: apps/v1 + kind: Deployment + name: rivet-engine + minReplicas: 2 # Adjust minimum replicas + maxReplicas: 20 # Adjust maximum replicas + metrics: + - type: Resource + resource: + name: cpu + target: + type: Utilization + averageUtilization: 70 # Adjust CPU threshold + - type: Resource + resource: + name: memory + target: + type: Utilization + averageUtilization: 80 # Adjust memory threshold +``` + +Monitor HPA status: -See our [HPA set up on github](https://github.com/rivet-gg/rivet/tree/main/k8s/engines/05-rivet-engine-hpa.yaml) for info on configuring automatic horizontal scaling. +```bash +kubectl -n rivet-engine get hpa +kubectl -n rivet-engine describe hpa rivet-engine +``` ## Next Steps diff --git a/website/public/llms.txt b/website/public/llms.txt index b2d6fa1cf4..92b2c848d1 100644 --- a/website/public/llms.txt +++ b/website/public/llms.txt @@ -24,8 +24,11 @@ https://rivet.dev/blog/2025-10-01-railway-selfhost https://rivet.dev/blog/2025-10-05-weekly-updates https://rivet.dev/blog/2025-10-09-rivet-cloud-launch https://rivet.dev/blog/2025-10-17-rivet-actors-vercel +https://rivet.dev/blog/2025-10-19-weekly-updates https://rivet.dev/blog/2025-10-20-how-we-built-websocket-servers-for-vercel-functions https://rivet.dev/blog/2025-10-20-weekly-updates +https://rivet.dev/blog/2025-10-24-weekly-updates +https://rivet.dev/blog/2025-11-02-weekly-updates https://rivet.dev/blog/godot-multiplayer-compared-to-unity https://rivet.dev/changelog https://rivet.dev/changelog.json @@ -50,8 +53,11 @@ https://rivet.dev/changelog/2025-10-01-railway-selfhost https://rivet.dev/changelog/2025-10-05-weekly-updates https://rivet.dev/changelog/2025-10-09-rivet-cloud-launch https://rivet.dev/changelog/2025-10-17-rivet-actors-vercel +https://rivet.dev/changelog/2025-10-19-weekly-updates https://rivet.dev/changelog/2025-10-20-how-we-built-websocket-servers-for-vercel-functions https://rivet.dev/changelog/2025-10-20-weekly-updates +https://rivet.dev/changelog/2025-10-24-weekly-updates +https://rivet.dev/changelog/2025-11-02-weekly-updates https://rivet.dev/changelog/godot-multiplayer-compared-to-unity https://rivet.dev/cloud https://rivet.dev/docs/actors