Skip to content

Commit 14348ed

Browse files
committed
refactor(rivetkit): extract WebSocket protocol parsing to shared utility
1 parent 27c352a commit 14348ed

File tree

3 files changed

+39
-58
lines changed

3 files changed

+39
-58
lines changed

rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ import {
1414
HEADER_ACTOR_QUERY,
1515
HEADER_CONN_PARAMS,
1616
HEADER_ENCODING,
17+
WS_PROTOCOL_CONN_PARAMS,
18+
WS_PROTOCOL_ENCODING,
1719
} from "@/common/actor-router-consts";
1820
import type { UpgradeWebSocketArgs } from "@/common/inline-websocket-adapter2";
1921
import { deconstructError, stringifyError } from "@/common/utils";
@@ -517,10 +519,12 @@ export async function handleRawWebSocket(
517519
}
518520

519521
// Helper to get the connection encoding from a request
522+
//
523+
// Defaults to JSON if not provided so we can support vanilla curl requests easily.
520524
export function getRequestEncoding(req: HonoRequest): Encoding {
521525
const encodingParam = req.header(HEADER_ENCODING);
522526
if (!encodingParam) {
523-
throw new errors.InvalidEncoding("undefined");
527+
return "json";
524528
}
525529

526530
const result = EncodingSchema.safeParse(encodingParam);
@@ -570,6 +574,35 @@ export function getRequestConnParams(req: HonoRequest): unknown {
570574
}
571575
}
572576

577+
/**
578+
* Parse encoding and connection parameters from WebSocket Sec-WebSocket-Protocol header
579+
*/
580+
export function parseWebSocketProtocols(protocols: string | null | undefined): {
581+
encoding: Encoding;
582+
connParams: unknown;
583+
} {
584+
let encodingRaw: string | undefined;
585+
let connParamsRaw: string | undefined;
586+
587+
if (protocols) {
588+
const protocolList = protocols.split(",").map((p) => p.trim());
589+
for (const protocol of protocolList) {
590+
if (protocol.startsWith(WS_PROTOCOL_ENCODING)) {
591+
encodingRaw = protocol.substring(WS_PROTOCOL_ENCODING.length);
592+
} else if (protocol.startsWith(WS_PROTOCOL_CONN_PARAMS)) {
593+
connParamsRaw = decodeURIComponent(
594+
protocol.substring(WS_PROTOCOL_CONN_PARAMS.length),
595+
);
596+
}
597+
}
598+
}
599+
600+
const encoding = EncodingSchema.parse(encodingRaw);
601+
const connParams = connParamsRaw ? JSON.parse(connParamsRaw) : undefined;
602+
603+
return { encoding, connParams };
604+
}
605+
573606
/**
574607
* Truncase the PATH_WEBSOCKET_PREFIX path prefix in order to pass a clean
575608
* path to the onWebSocket handler.

rivetkit-typescript/packages/rivetkit/src/actor/router.ts

Lines changed: 3 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import { Hono } from "hono";
22
import invariant from "invariant";
3-
import { EncodingSchema } from "@/actor/protocol/serde";
43
import {
54
type ActionOpts,
65
type ActionOutput,
@@ -11,12 +10,11 @@ import {
1110
handleRawRequest,
1211
handleRawWebSocket,
1312
handleWebSocketConnect,
13+
parseWebSocketProtocols,
1414
} from "@/actor/router-endpoints";
1515
import {
1616
PATH_CONNECT,
1717
PATH_WEBSOCKET_PREFIX,
18-
WS_PROTOCOL_CONN_PARAMS,
19-
WS_PROTOCOL_ENCODING,
2018
} from "@/common/actor-router-consts";
2119
import {
2220
handleRouteError,
@@ -114,34 +112,8 @@ export function createActorRouter(
114112
return upgradeWebSocket(async (c) => {
115113
// Parse configuration from Sec-WebSocket-Protocol header
116114
const protocols = c.req.header("sec-websocket-protocol");
117-
let encodingRaw: string | undefined;
118-
let connParamsRaw: string | undefined;
119-
120-
if (protocols) {
121-
const protocolList = protocols
122-
.split(",")
123-
.map((p) => p.trim());
124-
for (const protocol of protocolList) {
125-
if (protocol.startsWith(WS_PROTOCOL_ENCODING)) {
126-
encodingRaw = protocol.substring(
127-
WS_PROTOCOL_ENCODING.length,
128-
);
129-
} else if (
130-
protocol.startsWith(WS_PROTOCOL_CONN_PARAMS)
131-
) {
132-
connParamsRaw = decodeURIComponent(
133-
protocol.substring(
134-
WS_PROTOCOL_CONN_PARAMS.length,
135-
),
136-
);
137-
}
138-
}
139-
}
140-
141-
const encoding = EncodingSchema.parse(encodingRaw);
142-
const connParams = connParamsRaw
143-
? JSON.parse(connParamsRaw)
144-
: undefined;
115+
const { encoding, connParams } =
116+
parseWebSocketProtocols(protocols);
145117

146118
return await handleWebSocketConnect(
147119
c.req.raw,

rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts

Lines changed: 2 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,17 @@ import { lookupInRegistry } from "@/actor/definition";
1313
import { KEYS } from "@/actor/instance/kv";
1414
import { ACTOR_INSTANCE_PERSIST_SYMBOL } from "@/actor/instance/mod";
1515
import { deserializeActorKey } from "@/actor/keys";
16-
import { EncodingSchema } from "@/actor/protocol/serde";
1716
import { type ActorRouter, createActorRouter } from "@/actor/router";
1817
import {
1918
handleRawWebSocket,
2019
handleWebSocketConnect,
20+
parseWebSocketProtocols,
2121
truncateRawWebSocketPathPrefix,
2222
} from "@/actor/router-endpoints";
2323
import type { Client } from "@/client/client";
2424
import {
2525
PATH_CONNECT,
2626
PATH_WEBSOCKET_PREFIX,
27-
WS_PROTOCOL_CONN_PARAMS,
28-
WS_PROTOCOL_ENCODING,
2927
} from "@/common/actor-router-consts";
3028
import type { UpgradeWebSocketArgs } from "@/common/inline-websocket-adapter2";
3129
import { getLogger } from "@/common/log";
@@ -542,29 +540,7 @@ export class EngineActorDriver implements ActorDriver {
542540

543541
// Parse configuration from Sec-WebSocket-Protocol header (optional for path-based routing)
544542
const protocols = request.headers.get("sec-websocket-protocol");
545-
546-
let encodingRaw: string | undefined;
547-
let connParamsRaw: string | undefined;
548-
549-
if (protocols) {
550-
const protocolList = protocols.split(",").map((p) => p.trim());
551-
for (const protocol of protocolList) {
552-
if (protocol.startsWith(WS_PROTOCOL_ENCODING)) {
553-
encodingRaw = protocol.substring(
554-
WS_PROTOCOL_ENCODING.length,
555-
);
556-
} else if (protocol.startsWith(WS_PROTOCOL_CONN_PARAMS)) {
557-
connParamsRaw = decodeURIComponent(
558-
protocol.substring(WS_PROTOCOL_CONN_PARAMS.length),
559-
);
560-
}
561-
}
562-
}
563-
564-
const encoding = EncodingSchema.parse(encodingRaw);
565-
const connParams = connParamsRaw
566-
? JSON.parse(connParamsRaw)
567-
: undefined;
543+
const { encoding, connParams } = parseWebSocketProtocols(protocols);
568544

569545
// Fetch WS handler
570546
//

0 commit comments

Comments
 (0)