Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import {
HEADER_ACTOR_QUERY,
HEADER_CONN_PARAMS,
HEADER_ENCODING,
WS_PROTOCOL_CONN_PARAMS,
WS_PROTOCOL_ENCODING,
} from "@/common/actor-router-consts";
import type { UpgradeWebSocketArgs } from "@/common/inline-websocket-adapter2";
import { deconstructError, stringifyError } from "@/common/utils";
Expand Down Expand Up @@ -517,10 +519,12 @@ export async function handleRawWebSocket(
}

// Helper to get the connection encoding from a request
//
// Defaults to JSON if not provided so we can support vanilla curl requests easily.
export function getRequestEncoding(req: HonoRequest): Encoding {
const encodingParam = req.header(HEADER_ENCODING);
if (!encodingParam) {
throw new errors.InvalidEncoding("undefined");
return "json";
}

const result = EncodingSchema.safeParse(encodingParam);
Expand Down Expand Up @@ -570,6 +574,35 @@ export function getRequestConnParams(req: HonoRequest): unknown {
}
}

/**
* Parse encoding and connection parameters from WebSocket Sec-WebSocket-Protocol header
*/
export function parseWebSocketProtocols(protocols: string | null | undefined): {
encoding: Encoding;
connParams: unknown;
} {
let encodingRaw: string | undefined;
let connParamsRaw: string | undefined;

if (protocols) {
const protocolList = protocols.split(",").map((p) => p.trim());
for (const protocol of protocolList) {
if (protocol.startsWith(WS_PROTOCOL_ENCODING)) {
encodingRaw = protocol.substring(WS_PROTOCOL_ENCODING.length);
} else if (protocol.startsWith(WS_PROTOCOL_CONN_PARAMS)) {
connParamsRaw = decodeURIComponent(
protocol.substring(WS_PROTOCOL_CONN_PARAMS.length),
);
}
}
}

const encoding = EncodingSchema.parse(encodingRaw);
const connParams = connParamsRaw ? JSON.parse(connParamsRaw) : undefined;

return { encoding, connParams };
}

/**
* Truncase the PATH_WEBSOCKET_PREFIX path prefix in order to pass a clean
* path to the onWebSocket handler.
Expand Down
34 changes: 3 additions & 31 deletions rivetkit-typescript/packages/rivetkit/src/actor/router.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import { Hono } from "hono";
import invariant from "invariant";
import { EncodingSchema } from "@/actor/protocol/serde";
import {
type ActionOpts,
type ActionOutput,
Expand All @@ -11,12 +10,11 @@ import {
handleRawRequest,
handleRawWebSocket,
handleWebSocketConnect,
parseWebSocketProtocols,
} from "@/actor/router-endpoints";
import {
PATH_CONNECT,
PATH_WEBSOCKET_PREFIX,
WS_PROTOCOL_CONN_PARAMS,
WS_PROTOCOL_ENCODING,
} from "@/common/actor-router-consts";
import {
handleRouteError,
Expand Down Expand Up @@ -114,34 +112,8 @@ export function createActorRouter(
return upgradeWebSocket(async (c) => {
// Parse configuration from Sec-WebSocket-Protocol header
const protocols = c.req.header("sec-websocket-protocol");
let encodingRaw: string | undefined;
let connParamsRaw: string | undefined;

if (protocols) {
const protocolList = protocols
.split(",")
.map((p) => p.trim());
for (const protocol of protocolList) {
if (protocol.startsWith(WS_PROTOCOL_ENCODING)) {
encodingRaw = protocol.substring(
WS_PROTOCOL_ENCODING.length,
);
} else if (
protocol.startsWith(WS_PROTOCOL_CONN_PARAMS)
) {
connParamsRaw = decodeURIComponent(
protocol.substring(
WS_PROTOCOL_CONN_PARAMS.length,
),
);
}
}
}

const encoding = EncodingSchema.parse(encodingRaw);
const connParams = connParamsRaw
? JSON.parse(connParamsRaw)
: undefined;
const { encoding, connParams } =
parseWebSocketProtocols(protocols);

return await handleWebSocketConnect(
c.req.raw,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,17 @@ import { lookupInRegistry } from "@/actor/definition";
import { KEYS } from "@/actor/instance/kv";
import { ACTOR_INSTANCE_PERSIST_SYMBOL } from "@/actor/instance/mod";
import { deserializeActorKey } from "@/actor/keys";
import { EncodingSchema } from "@/actor/protocol/serde";
import { type ActorRouter, createActorRouter } from "@/actor/router";
import {
handleRawWebSocket,
handleWebSocketConnect,
parseWebSocketProtocols,
truncateRawWebSocketPathPrefix,
} from "@/actor/router-endpoints";
import type { Client } from "@/client/client";
import {
PATH_CONNECT,
PATH_WEBSOCKET_PREFIX,
WS_PROTOCOL_CONN_PARAMS,
WS_PROTOCOL_ENCODING,
} from "@/common/actor-router-consts";
import type { UpgradeWebSocketArgs } from "@/common/inline-websocket-adapter2";
import { getLogger } from "@/common/log";
Expand Down Expand Up @@ -542,29 +540,7 @@ export class EngineActorDriver implements ActorDriver {

// Parse configuration from Sec-WebSocket-Protocol header (optional for path-based routing)
const protocols = request.headers.get("sec-websocket-protocol");

let encodingRaw: string | undefined;
let connParamsRaw: string | undefined;

if (protocols) {
const protocolList = protocols.split(",").map((p) => p.trim());
for (const protocol of protocolList) {
if (protocol.startsWith(WS_PROTOCOL_ENCODING)) {
encodingRaw = protocol.substring(
WS_PROTOCOL_ENCODING.length,
);
} else if (protocol.startsWith(WS_PROTOCOL_CONN_PARAMS)) {
connParamsRaw = decodeURIComponent(
protocol.substring(WS_PROTOCOL_CONN_PARAMS.length),
);
}
}
}

const encoding = EncodingSchema.parse(encodingRaw);
const connParams = connParamsRaw
? JSON.parse(connParamsRaw)
: undefined;
const { encoding, connParams } = parseWebSocketProtocols(protocols);

// Fetch WS handler
//
Expand Down
Loading