From be3dc7d4b6b0284f1449fd355c898625f093b7da Mon Sep 17 00:00:00 2001 From: Nathan Flurry Date: Sun, 9 Nov 2025 17:30:33 -0800 Subject: [PATCH] chore(rivetkit): add required zod validation for json encoding --- .../actor/client-protocol-schema-json/mod.ts | 105 ++++++++++++++++++ .../packages/rivetkit/src/actor/conn/mod.ts | 2 + .../src/actor/instance/event-manager.ts | 2 + .../rivetkit/src/actor/instance/mod.ts | 2 + .../rivetkit/src/actor/protocol/old.ts | 13 ++- .../rivetkit/src/actor/protocol/serde.ts | 15 ++- .../rivetkit/src/actor/router-endpoints.ts | 6 + .../rivetkit/src/client/actor-conn.ts | 6 + .../rivetkit/src/client/actor-handle.ts | 6 + .../packages/rivetkit/src/client/utils.ts | 28 ++++- .../packages/rivetkit/src/common/router.ts | 2 + .../src/remote-manager-driver/api-utils.ts | 3 + .../packages/rivetkit/src/serde.ts | 22 ++-- 13 files changed, 193 insertions(+), 19 deletions(-) create mode 100644 rivetkit-typescript/packages/rivetkit/src/actor/client-protocol-schema-json/mod.ts diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/client-protocol-schema-json/mod.ts b/rivetkit-typescript/packages/rivetkit/src/actor/client-protocol-schema-json/mod.ts new file mode 100644 index 0000000000..72b53ed0fe --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/src/actor/client-protocol-schema-json/mod.ts @@ -0,0 +1,105 @@ +import { z } from "zod"; + +// Helper schemas for ArrayBuffer handling in JSON +const ArrayBufferSchema = z.instanceof(ArrayBuffer); +const OptionalArrayBufferSchema = ArrayBufferSchema.nullable(); +const UintSchema = z.bigint(); +const OptionalUintSchema = UintSchema.nullable(); + +// MARK: Message To Client +export const InitSchema = z.object({ + actorId: z.string(), + connectionId: z.string(), +}); +export type Init = z.infer; + +export const ErrorSchema = z.object({ + group: z.string(), + code: z.string(), + message: z.string(), + metadata: OptionalArrayBufferSchema, + actionId: OptionalUintSchema, +}); +export type Error = z.infer; + +export const ActionResponseSchema = z.object({ + id: UintSchema, + output: ArrayBufferSchema, +}); +export type ActionResponse = z.infer; + +export const EventSchema = z.object({ + name: z.string(), + args: ArrayBufferSchema, +}); +export type Event = z.infer; + +export const ToClientBodySchema = z.discriminatedUnion("tag", [ + z.object({ tag: z.literal("Init"), val: InitSchema }), + z.object({ tag: z.literal("Error"), val: ErrorSchema }), + z.object({ tag: z.literal("ActionResponse"), val: ActionResponseSchema }), + z.object({ tag: z.literal("Event"), val: EventSchema }), +]); +export type ToClientBody = z.infer; + +export const ToClientSchema = z.object({ + body: ToClientBodySchema, +}); +export type ToClient = z.infer; + +// MARK: Message To Server +export const ActionRequestSchema = z.object({ + id: UintSchema, + name: z.string(), + args: ArrayBufferSchema, +}); +export type ActionRequest = z.infer; + +export const SubscriptionRequestSchema = z.object({ + eventName: z.string(), + subscribe: z.boolean(), +}); +export type SubscriptionRequest = z.infer; + +export const ToServerBodySchema = z.discriminatedUnion("tag", [ + z.object({ tag: z.literal("ActionRequest"), val: ActionRequestSchema }), + z.object({ + tag: z.literal("SubscriptionRequest"), + val: SubscriptionRequestSchema, + }), +]); +export type ToServerBody = z.infer; + +export const ToServerSchema = z.object({ + body: ToServerBodySchema, +}); +export type ToServer = z.infer; + +// MARK: HTTP Action +export const HttpActionRequestSchema = z.object({ + args: ArrayBufferSchema, +}); +export type HttpActionRequest = z.infer; + +export const HttpActionResponseSchema = z.object({ + output: ArrayBufferSchema, +}); +export type HttpActionResponse = z.infer; + +// MARK: HTTP Error +export const HttpResponseErrorSchema = z.object({ + group: z.string(), + code: z.string(), + message: z.string(), + metadata: OptionalArrayBufferSchema, +}); +export type HttpResponseError = z.infer; + +// MARK: HTTP Resolve +export const HttpResolveRequestSchema = z.null(); +export type HttpResolveRequest = z.infer; + +export const HttpResolveResponseSchema = z.object({ + actorId: z.string(), +}); +export type HttpResolveResponse = z.infer; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts index 25d8395722..fc35744165 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts @@ -1,4 +1,5 @@ import * as cbor from "cbor-x"; +import { ToClientSchema } from "@/actor/client-protocol-schema-json/mod"; import type * as protocol from "@/schemas/client-protocol/mod"; import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned"; import { arrayBuffersEqual, bufferToArrayBuffer } from "@/utils"; @@ -205,6 +206,7 @@ export class Conn { }, }, TO_CLIENT_VERSIONED, + ToClientSchema, ), ); } diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts index 1f67fd6e1a..a8ce4a868f 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts @@ -1,4 +1,5 @@ import * as cbor from "cbor-x"; +import { ToClientSchema } from "@/actor/client-protocol-schema-json/mod"; import type * as protocol from "@/schemas/client-protocol/mod"; import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned"; import { bufferToArrayBuffer } from "@/utils"; @@ -190,6 +191,7 @@ export class EventManager { }, }, TO_CLIENT_VERSIONED, + ToClientSchema, ); // Send to all subscribers diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts index 4170993b98..fdcff7bdff 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts @@ -1,5 +1,6 @@ import * as cbor from "cbor-x"; import invariant from "invariant"; +import { ToClientSchema } from "@/actor/client-protocol-schema-json/mod"; import type { ActorKey } from "@/actor/mod"; import type { Client } from "@/client/client"; import { getBaseLogger, getIncludeTarget, type Logger } from "@/common/log"; @@ -522,6 +523,7 @@ export class ActorInstance { }, }, TO_CLIENT_VERSIONED, + ToClientSchema, ), ); diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/protocol/old.ts b/rivetkit-typescript/packages/rivetkit/src/actor/protocol/old.ts index feadf2d244..37aed9dc91 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/protocol/old.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/protocol/old.ts @@ -1,5 +1,9 @@ import * as cbor from "cbor-x"; import { z } from "zod"; +import { + ToClientSchema, + ToServerSchema, +} from "@/actor/client-protocol-schema-json/mod"; import type { AnyDatabaseProvider } from "@/actor/database"; import * as errors from "@/actor/errors"; import { @@ -81,7 +85,12 @@ export async function parseMessage( } // Deserialize message - return deserializeWithEncoding(opts.encoding, buffer, TO_SERVER_VERSIONED); + return deserializeWithEncoding( + opts.encoding, + buffer, + TO_SERVER_VERSIONED, + ToServerSchema, + ); } export interface ProcessMessageHandler< @@ -171,6 +180,7 @@ export async function processMessage< }, }, TO_CLIENT_VERSIONED, + ToClientSchema, ), ); @@ -243,6 +253,7 @@ export async function processMessage< }, }, TO_CLIENT_VERSIONED, + ToClientSchema, ), ); diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/protocol/serde.ts b/rivetkit-typescript/packages/rivetkit/src/actor/protocol/serde.ts index 1c84128f88..157e6d3aa5 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/protocol/serde.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/protocol/serde.ts @@ -22,17 +22,23 @@ export type Encoding = z.infer; /** * Helper class that helps serialize data without re-serializing for the same encoding. */ -export class CachedSerializer { - #data: T; +export class CachedSerializer { + #data: T | TJson; #cache = new Map(); #versionedDataHandler: VersionedDataHandler; + #zodSchema: z.ZodType; - constructor(data: T, versionedDataHandler: VersionedDataHandler) { + constructor( + data: T | TJson, + versionedDataHandler: VersionedDataHandler, + zodSchema: z.ZodType, + ) { this.#data = data; this.#versionedDataHandler = versionedDataHandler; + this.#zodSchema = zodSchema; } - public get rawData(): T { + public get rawData(): T | TJson { return this.#data; } @@ -45,6 +51,7 @@ export class CachedSerializer { encoding, this.#data, this.#versionedDataHandler, + this.#zodSchema, ); this.#cache.set(encoding, serialized); return serialized; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts b/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts index 19f499e437..9f2d4a08fe 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts @@ -1,6 +1,10 @@ import * as cbor from "cbor-x"; import type { Context as HonoContext, HonoRequest } from "hono"; import type { WSContext } from "hono/ws"; +import { + HttpActionRequestSchema, + HttpActionResponseSchema, +} from "@/actor/client-protocol-schema-json/mod"; import type { AnyConn } from "@/actor/conn/mod"; import { ActionContext } from "@/actor/contexts/action"; import * as errors from "@/actor/errors"; @@ -334,6 +338,7 @@ export async function handleAction( encoding, new Uint8Array(arrayBuffer), HTTP_ACTION_REQUEST_VERSIONED, + HttpActionRequestSchema, ); const actionArgs = cbor.decode(new Uint8Array(request.args)); @@ -371,6 +376,7 @@ export async function handleAction( encoding, responseData, HTTP_ACTION_RESPONSE_VERSIONED, + HttpActionResponseSchema, ); // TODO: Remvoe any, Hono is being a dumbass diff --git a/rivetkit-typescript/packages/rivetkit/src/client/actor-conn.ts b/rivetkit-typescript/packages/rivetkit/src/client/actor-conn.ts index 73f8c16e5f..006a7a4b35 100644 --- a/rivetkit-typescript/packages/rivetkit/src/client/actor-conn.ts +++ b/rivetkit-typescript/packages/rivetkit/src/client/actor-conn.ts @@ -2,6 +2,10 @@ import * as cbor from "cbor-x"; import invariant from "invariant"; import pRetry from "p-retry"; import type { CloseEvent } from "ws"; +import { + ToClientSchema, + ToServerSchema, +} from "@/actor/client-protocol-schema-json/mod"; import type { AnyActorDefinition } from "@/actor/definition"; import { inputDataToBuffer } from "@/actor/protocol/old"; import { type Encoding, jsonStringifyCompat } from "@/actor/protocol/serde"; @@ -693,6 +697,7 @@ enc this.#encoding, message, TO_SERVER_VERSIONED, + ToServerSchema, ); this.#websocket.send(messageSerialized); logger().trace({ @@ -743,6 +748,7 @@ enc this.#encoding, buffer, TO_CLIENT_VERSIONED, + ToClientSchema, ); } diff --git a/rivetkit-typescript/packages/rivetkit/src/client/actor-handle.ts b/rivetkit-typescript/packages/rivetkit/src/client/actor-handle.ts index 22eb2e5b66..de1f2c7e05 100644 --- a/rivetkit-typescript/packages/rivetkit/src/client/actor-handle.ts +++ b/rivetkit-typescript/packages/rivetkit/src/client/actor-handle.ts @@ -1,5 +1,9 @@ import * as cbor from "cbor-x"; import invariant from "invariant"; +import { + HttpActionRequestSchema, + HttpActionResponseSchema, +} from "@/actor/client-protocol-schema-json/mod"; import type { AnyActorDefinition } from "@/actor/definition"; import type { Encoding } from "@/actor/protocol/serde"; import { assertUnreachable } from "@/actor/utils"; @@ -122,6 +126,8 @@ export class ActorHandleRaw { signal: opts?.signal, requestVersionedDataHandler: HTTP_ACTION_REQUEST_VERSIONED, responseVersionedDataHandler: HTTP_ACTION_RESPONSE_VERSIONED, + requestZodSchema: HttpActionRequestSchema, + responseZodSchema: HttpActionResponseSchema, }); return cbor.decode(new Uint8Array(responseData.output)); diff --git a/rivetkit-typescript/packages/rivetkit/src/client/utils.ts b/rivetkit-typescript/packages/rivetkit/src/client/utils.ts index 4de17e909a..1f8d4349e6 100644 --- a/rivetkit-typescript/packages/rivetkit/src/client/utils.ts +++ b/rivetkit-typescript/packages/rivetkit/src/client/utils.ts @@ -1,5 +1,7 @@ import * as cbor from "cbor-x"; import invariant from "invariant"; +import type { z } from "zod"; +import { HttpResponseErrorSchema } from "@/actor/client-protocol-schema-json/mod"; import type { Encoding } from "@/actor/protocol/serde"; import { assertUnreachable } from "@/common/utils"; import type { VersionedDataHandler } from "@/common/versioned-data"; @@ -33,11 +35,16 @@ export function messageLength(message: WebSocketMessage): number { assertUnreachable(message); } -export interface HttpRequestOpts { +export interface HttpRequestOpts< + RequestBody, + ResponseBody, + RequestJson = RequestBody, + ResponseJson = ResponseBody, +> { method: string; url: string; headers: Record; - body?: RequestBody; + body?: RequestBody | RequestJson; encoding: Encoding; skipParseResponse?: boolean; signal?: AbortSignal; @@ -46,12 +53,18 @@ export interface HttpRequestOpts { responseVersionedDataHandler: | VersionedDataHandler | undefined; + requestZodSchema: z.ZodType; + responseZodSchema: z.ZodType; } export async function sendHttpRequest< RequestBody = unknown, ResponseBody = unknown, ->(opts: HttpRequestOpts): Promise { + RequestJson = RequestBody, + ResponseJson = ResponseBody, +>( + opts: HttpRequestOpts, +): Promise { logger().debug({ msg: "sending http request", url: opts.url, @@ -64,10 +77,11 @@ export async function sendHttpRequest< if (opts.method === "POST" || opts.method === "PUT") { invariant(opts.body !== undefined, "missing body"); contentType = contentTypeForEncoding(opts.encoding); - bodyData = serializeWithEncoding( + bodyData = serializeWithEncoding( opts.encoding, opts.body, opts.requestVersionedDataHandler, + opts.requestZodSchema, ); } @@ -108,6 +122,7 @@ export async function sendHttpRequest< opts.encoding, new Uint8Array(bufferResponse), HTTP_RESPONSE_ERROR_VERSIONED, + HttpResponseErrorSchema, ); } catch (error) { //logger().warn("failed to cleanly parse error, this is likely because a non-structured response is being served", { @@ -151,16 +166,17 @@ export async function sendHttpRequest< // Some requests don't need the success response to be parsed, so this can speed things up if (opts.skipParseResponse) { - return undefined as ResponseBody; + return undefined as ResponseBody | ResponseJson; } // Parse the response based on encoding try { const buffer = new Uint8Array(await response.arrayBuffer()); - return deserializeWithEncoding( + return deserializeWithEncoding( opts.encoding, buffer, opts.responseVersionedDataHandler, + opts.responseZodSchema, ); } catch (error) { throw new HttpRequestError(`Failed to parse response: ${error}`, { diff --git a/rivetkit-typescript/packages/rivetkit/src/common/router.ts b/rivetkit-typescript/packages/rivetkit/src/common/router.ts index 64e2f8fd15..ded59cc47a 100644 --- a/rivetkit-typescript/packages/rivetkit/src/common/router.ts +++ b/rivetkit-typescript/packages/rivetkit/src/common/router.ts @@ -1,5 +1,6 @@ import * as cbor from "cbor-x"; import type { Context as HonoContext, Next } from "hono"; +import { HttpResponseErrorSchema } from "@/actor/client-protocol-schema-json/mod"; import type { Encoding } from "@/actor/protocol/serde"; import { getRequestEncoding, @@ -80,6 +81,7 @@ export function handleRouteError(error: unknown, c: HonoContext) { : null, }, HTTP_RESPONSE_ERROR_VERSIONED, + HttpResponseErrorSchema, ); // TODO: Remove any diff --git a/rivetkit-typescript/packages/rivetkit/src/remote-manager-driver/api-utils.ts b/rivetkit-typescript/packages/rivetkit/src/remote-manager-driver/api-utils.ts index 1966270501..aa4c0e49c0 100644 --- a/rivetkit-typescript/packages/rivetkit/src/remote-manager-driver/api-utils.ts +++ b/rivetkit-typescript/packages/rivetkit/src/remote-manager-driver/api-utils.ts @@ -1,3 +1,4 @@ +import { z } from "zod"; import type { ClientConfig } from "@/client/config"; import { sendHttpRequest } from "@/client/utils"; import { combineUrlPath } from "@/utils"; @@ -51,5 +52,7 @@ export async function apiCall( skipParseResponse: false, requestVersionedDataHandler: undefined, responseVersionedDataHandler: undefined, + requestZodSchema: z.any() as z.ZodType, + responseZodSchema: z.any() as z.ZodType, }); } diff --git a/rivetkit-typescript/packages/rivetkit/src/serde.ts b/rivetkit-typescript/packages/rivetkit/src/serde.ts index b51ad5dbb7..c62ffcf70c 100644 --- a/rivetkit-typescript/packages/rivetkit/src/serde.ts +++ b/rivetkit-typescript/packages/rivetkit/src/serde.ts @@ -1,5 +1,6 @@ import * as cbor from "cbor-x"; import invariant from "invariant"; +import type { z } from "zod"; import { assertUnreachable } from "@/common/utils"; import type { VersionedDataHandler } from "@/common/versioned-data"; import type { Encoding } from "@/mod"; @@ -52,13 +53,15 @@ export function wsBinaryTypeForEncoding( } } -export function serializeWithEncoding( +export function serializeWithEncoding( encoding: Encoding, - value: T, + value: T | TJson, versionedDataHandler: VersionedDataHandler | undefined, + zodSchema: z.ZodType, ): Uint8Array | string { if (encoding === "json") { - return jsonStringifyCompat(value); + const validated = zodSchema.parse(value); + return jsonStringifyCompat(validated); } else if (encoding === "cbor") { return cbor.encode(value); } else if (encoding === "bare") { @@ -67,25 +70,28 @@ export function serializeWithEncoding( "VersionedDataHandler is required for 'bare' encoding", ); } - return versionedDataHandler.serializeWithEmbeddedVersion(value); + return versionedDataHandler.serializeWithEmbeddedVersion(value as T); } else { assertUnreachable(encoding); } } -export function deserializeWithEncoding( +export function deserializeWithEncoding( encoding: Encoding, buffer: Uint8Array | string, versionedDataHandler: VersionedDataHandler | undefined, -): T { + zodSchema: z.ZodType, +): T | TJson { if (encoding === "json") { + let parsed: unknown; if (typeof buffer === "string") { - return jsonParseCompat(buffer); + parsed = jsonParseCompat(buffer); } else { const decoder = new TextDecoder("utf-8"); const jsonString = decoder.decode(buffer); - return jsonParseCompat(jsonString); + parsed = jsonParseCompat(jsonString); } + return zodSchema.parse(parsed); } else if (encoding === "cbor") { invariant( typeof buffer !== "string",