From 94422d8189623cb5290463a1cc30aa5820d3ad41 Mon Sep 17 00:00:00 2001 From: Nathan Flurry Date: Sun, 9 Nov 2025 16:40:47 -0800 Subject: [PATCH] refactor(rivetkit): extract WebSocket protocol parsing to shared utility --- .../rivetkit/src/actor/router-endpoints.ts | 35 ++++++++++++++++++- .../packages/rivetkit/src/actor/router.ts | 34 ++---------------- .../src/drivers/engine/actor-driver.ts | 28 ++------------- 3 files changed, 39 insertions(+), 58 deletions(-) diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts b/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts index 4a6cd77a46..19f499e437 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts @@ -14,6 +14,8 @@ import { HEADER_ACTOR_QUERY, HEADER_CONN_PARAMS, HEADER_ENCODING, + WS_PROTOCOL_CONN_PARAMS, + WS_PROTOCOL_ENCODING, } from "@/common/actor-router-consts"; import type { UpgradeWebSocketArgs } from "@/common/inline-websocket-adapter2"; import { deconstructError, stringifyError } from "@/common/utils"; @@ -517,10 +519,12 @@ export async function handleRawWebSocket( } // Helper to get the connection encoding from a request +// +// Defaults to JSON if not provided so we can support vanilla curl requests easily. export function getRequestEncoding(req: HonoRequest): Encoding { const encodingParam = req.header(HEADER_ENCODING); if (!encodingParam) { - throw new errors.InvalidEncoding("undefined"); + return "json"; } const result = EncodingSchema.safeParse(encodingParam); @@ -570,6 +574,35 @@ export function getRequestConnParams(req: HonoRequest): unknown { } } +/** + * Parse encoding and connection parameters from WebSocket Sec-WebSocket-Protocol header + */ +export function parseWebSocketProtocols(protocols: string | null | undefined): { + encoding: Encoding; + connParams: unknown; +} { + let encodingRaw: string | undefined; + let connParamsRaw: string | undefined; + + if (protocols) { + const protocolList = protocols.split(",").map((p) => p.trim()); + for (const protocol of protocolList) { + if (protocol.startsWith(WS_PROTOCOL_ENCODING)) { + encodingRaw = protocol.substring(WS_PROTOCOL_ENCODING.length); + } else if (protocol.startsWith(WS_PROTOCOL_CONN_PARAMS)) { + connParamsRaw = decodeURIComponent( + protocol.substring(WS_PROTOCOL_CONN_PARAMS.length), + ); + } + } + } + + const encoding = EncodingSchema.parse(encodingRaw); + const connParams = connParamsRaw ? JSON.parse(connParamsRaw) : undefined; + + return { encoding, connParams }; +} + /** * Truncase the PATH_WEBSOCKET_PREFIX path prefix in order to pass a clean * path to the onWebSocket handler. diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/router.ts b/rivetkit-typescript/packages/rivetkit/src/actor/router.ts index d61f180c7b..d56c980b30 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/router.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/router.ts @@ -1,6 +1,5 @@ import { Hono } from "hono"; import invariant from "invariant"; -import { EncodingSchema } from "@/actor/protocol/serde"; import { type ActionOpts, type ActionOutput, @@ -11,12 +10,11 @@ import { handleRawRequest, handleRawWebSocket, handleWebSocketConnect, + parseWebSocketProtocols, } from "@/actor/router-endpoints"; import { PATH_CONNECT, PATH_WEBSOCKET_PREFIX, - WS_PROTOCOL_CONN_PARAMS, - WS_PROTOCOL_ENCODING, } from "@/common/actor-router-consts"; import { handleRouteError, @@ -114,34 +112,8 @@ export function createActorRouter( return upgradeWebSocket(async (c) => { // Parse configuration from Sec-WebSocket-Protocol header const protocols = c.req.header("sec-websocket-protocol"); - let encodingRaw: string | undefined; - let connParamsRaw: string | undefined; - - if (protocols) { - const protocolList = protocols - .split(",") - .map((p) => p.trim()); - for (const protocol of protocolList) { - if (protocol.startsWith(WS_PROTOCOL_ENCODING)) { - encodingRaw = protocol.substring( - WS_PROTOCOL_ENCODING.length, - ); - } else if ( - protocol.startsWith(WS_PROTOCOL_CONN_PARAMS) - ) { - connParamsRaw = decodeURIComponent( - protocol.substring( - WS_PROTOCOL_CONN_PARAMS.length, - ), - ); - } - } - } - - const encoding = EncodingSchema.parse(encodingRaw); - const connParams = connParamsRaw - ? JSON.parse(connParamsRaw) - : undefined; + const { encoding, connParams } = + parseWebSocketProtocols(protocols); return await handleWebSocketConnect( c.req.raw, diff --git a/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts b/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts index 7f86026898..799e6bdd26 100644 --- a/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts +++ b/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts @@ -13,19 +13,17 @@ import { lookupInRegistry } from "@/actor/definition"; import { KEYS } from "@/actor/instance/kv"; import { ACTOR_INSTANCE_PERSIST_SYMBOL } from "@/actor/instance/mod"; import { deserializeActorKey } from "@/actor/keys"; -import { EncodingSchema } from "@/actor/protocol/serde"; import { type ActorRouter, createActorRouter } from "@/actor/router"; import { handleRawWebSocket, handleWebSocketConnect, + parseWebSocketProtocols, truncateRawWebSocketPathPrefix, } from "@/actor/router-endpoints"; import type { Client } from "@/client/client"; import { PATH_CONNECT, PATH_WEBSOCKET_PREFIX, - WS_PROTOCOL_CONN_PARAMS, - WS_PROTOCOL_ENCODING, } from "@/common/actor-router-consts"; import type { UpgradeWebSocketArgs } from "@/common/inline-websocket-adapter2"; import { getLogger } from "@/common/log"; @@ -542,29 +540,7 @@ export class EngineActorDriver implements ActorDriver { // Parse configuration from Sec-WebSocket-Protocol header (optional for path-based routing) const protocols = request.headers.get("sec-websocket-protocol"); - - let encodingRaw: string | undefined; - let connParamsRaw: string | undefined; - - if (protocols) { - const protocolList = protocols.split(",").map((p) => p.trim()); - for (const protocol of protocolList) { - if (protocol.startsWith(WS_PROTOCOL_ENCODING)) { - encodingRaw = protocol.substring( - WS_PROTOCOL_ENCODING.length, - ); - } else if (protocol.startsWith(WS_PROTOCOL_CONN_PARAMS)) { - connParamsRaw = decodeURIComponent( - protocol.substring(WS_PROTOCOL_CONN_PARAMS.length), - ); - } - } - } - - const encoding = EncodingSchema.parse(encodingRaw); - const connParams = connParamsRaw - ? JSON.parse(connParamsRaw) - : undefined; + const { encoding, connParams } = parseWebSocketProtocols(protocols); // Fetch WS handler //