diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn/driver.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn/driver.ts index 72707af2ef..1c7da41096 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/conn/driver.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn/driver.ts @@ -32,7 +32,7 @@ export interface ConnDriver { sendMessage?( actor: AnyActorInstance, conn: AnyConn, - message: CachedSerializer, + message: CachedSerializer, ): void; /** diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/websocket.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/websocket.ts index 614d7f3cff..8f855a790d 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/websocket.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/websocket.ts @@ -23,7 +23,7 @@ export function createWebSocketSocket( sendMessage: ( actor: AnyActorInstance, conn: AnyConn, - message: CachedSerializer, + message: CachedSerializer, ) => { if (websocket.readyState !== DriverReadyState.OPEN) { actor.rLog.warn({ diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts index fc35744165..6f87b5ce9c 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts @@ -1,7 +1,10 @@ import * as cbor from "cbor-x"; -import { ToClientSchema } from "@/actor/client-protocol-schema-json/mod"; import type * as protocol from "@/schemas/client-protocol/mod"; import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned"; +import { + type ToClient as ToClientJson, + ToClientSchema, +} from "@/schemas/client-protocol-zod/mod"; import { arrayBuffersEqual, bufferToArrayBuffer } from "@/utils"; import type { AnyDatabaseProvider } from "../database"; import { @@ -161,7 +164,7 @@ export class Conn { return this.#stateManager.persistRaw; } - [CONN_SEND_MESSAGE_SYMBOL](message: CachedSerializer) { + [CONN_SEND_MESSAGE_SYMBOL](message: CachedSerializer) { if (this[CONN_DRIVER_SYMBOL]) { const driver = this[CONN_DRIVER_SYMBOL]; if (driver.sendMessage) { @@ -194,19 +197,32 @@ export class Conn { args, connId: this.id, }); + const eventData = { name: eventName, args }; this[CONN_SEND_MESSAGE_SYMBOL]( - new CachedSerializer( - { + new CachedSerializer( + eventData, + TO_CLIENT_VERSIONED, + ToClientSchema, + // JSON: args is the raw value (array of arguments) + (value): ToClientJson => ({ body: { - tag: "Event", + tag: "Event" as const, val: { - name: eventName, - args: bufferToArrayBuffer(cbor.encode(args)), + name: value.name, + args: value.args, }, }, - }, - TO_CLIENT_VERSIONED, - ToClientSchema, + }), + // BARE/CBOR: args needs to be CBOR-encoded to ArrayBuffer + (value): protocol.ToClient => ({ + body: { + tag: "Event" as const, + val: { + name: value.name, + args: bufferToArrayBuffer(cbor.encode(value.args)), + }, + }, + }), ), ); } diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts index a8ce4a868f..7c9ee8ff26 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts @@ -1,7 +1,10 @@ import * as cbor from "cbor-x"; -import { ToClientSchema } from "@/actor/client-protocol-schema-json/mod"; import type * as protocol from "@/schemas/client-protocol/mod"; import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned"; +import { + type ToClient as ToClientJson, + ToClientSchema, +} from "@/schemas/client-protocol-zod/mod"; import { bufferToArrayBuffer } from "@/utils"; import { CONN_PERSIST_SYMBOL, @@ -180,18 +183,31 @@ export class EventManager { } // Create serialized message - const toClientSerializer = new CachedSerializer( - { + const eventData = { name, args }; + const toClientSerializer = new CachedSerializer( + eventData, + TO_CLIENT_VERSIONED, + ToClientSchema, + // JSON: args is the raw value (array of arguments) + (value): ToClientJson => ({ body: { - tag: "Event", + tag: "Event" as const, val: { - name, - args: bufferToArrayBuffer(cbor.encode(args)), + name: value.name, + args: value.args, }, }, - }, - TO_CLIENT_VERSIONED, - ToClientSchema, + }), + // BARE/CBOR: args needs to be CBOR-encoded to ArrayBuffer + (value): protocol.ToClient => ({ + body: { + tag: "Event" as const, + val: { + name: value.name, + args: bufferToArrayBuffer(cbor.encode(value.args)), + }, + }, + }), ); // Send to all subscribers diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts index fdcff7bdff..ccadfcfa38 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts @@ -1,6 +1,5 @@ import * as cbor from "cbor-x"; import invariant from "invariant"; -import { ToClientSchema } from "@/actor/client-protocol-schema-json/mod"; import type { ActorKey } from "@/actor/mod"; import type { Client } from "@/client/client"; import { getBaseLogger, getIncludeTarget, type Logger } from "@/common/log"; @@ -11,6 +10,7 @@ import type { Registry } from "@/mod"; import { ACTOR_VERSIONED } from "@/schemas/actor-persist/versioned"; import type * as protocol from "@/schemas/client-protocol/mod"; import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned"; +import { ToClientSchema } from "@/schemas/client-protocol-zod/mod"; import { EXTRA_ERROR_LOG, idToStr } from "@/utils"; import type { ActorConfig, InitContext } from "../config"; import type { ConnDriver } from "../conn/driver"; @@ -511,19 +511,26 @@ export class ActorInstance { await this.saveState({ immediate: true }); // Send init message + const initData = { actorId: this.id, connectionId: conn.id }; conn[CONN_SEND_MESSAGE_SYMBOL]( - new CachedSerializer( - { - body: { - tag: "Init", - val: { - actorId: this.id, - connectionId: conn.id, - }, - }, - }, + new CachedSerializer( + initData, TO_CLIENT_VERSIONED, ToClientSchema, + // JSON: identity conversion (no nested data to encode) + (value) => ({ + body: { + tag: "Init" as const, + val: value, + }, + }), + // BARE/CBOR: identity conversion (no nested data to encode) + (value) => ({ + body: { + tag: "Init" as const, + val: value, + }, + }), ), ); @@ -532,7 +539,17 @@ export class ActorInstance { // MARK: - Message Processing async processMessage( - message: protocol.ToServer, + message: { + body: + | { + tag: "ActionRequest"; + val: { id: bigint; name: string; args: unknown }; + } + | { + tag: "SubscriptionRequest"; + val: { eventName: string; subscribe: boolean }; + }; + }, conn: Conn, ) { await processMessage(message, this, conn, { diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/protocol/old.ts b/rivetkit-typescript/packages/rivetkit/src/actor/protocol/old.ts index 37aed9dc91..0854a8d2e5 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/protocol/old.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/protocol/old.ts @@ -1,9 +1,5 @@ import * as cbor from "cbor-x"; import { z } from "zod"; -import { - ToClientSchema, - ToServerSchema, -} from "@/actor/client-protocol-schema-json/mod"; import type { AnyDatabaseProvider } from "@/actor/database"; import * as errors from "@/actor/errors"; import { @@ -17,6 +13,12 @@ import { TO_CLIENT_VERSIONED, TO_SERVER_VERSIONED, } from "@/schemas/client-protocol/versioned"; +import { + type ToClient as ToClientJson, + ToClientSchema, + type ToServer as ToServerJson, + ToServerSchema, +} from "@/schemas/client-protocol-zod/mod"; import { deserializeWithEncoding } from "@/serde"; import { assertUnreachable, bufferToArrayBuffer } from "../../utils"; import { CONN_SEND_MESSAGE_SYMBOL, type Conn } from "../conn/mod"; @@ -67,7 +69,17 @@ export async function inputDataToBuffer( export async function parseMessage( value: InputData, opts: MessageEventOpts, -): Promise { +): Promise<{ + body: + | { + tag: "ActionRequest"; + val: { id: bigint; name: string; args: unknown }; + } + | { + tag: "SubscriptionRequest"; + val: { eventName: string; subscribe: boolean }; + }; +}> { // Validate value length const length = getValueLength(value); if (length > opts.maxIncomingMessageSize) { @@ -90,6 +102,28 @@ export async function parseMessage( buffer, TO_SERVER_VERSIONED, ToServerSchema, + // JSON: values are already the correct type + (json: ToServerJson): any => json, + // BARE: need to decode ArrayBuffer fields back to unknown + (bare: protocol.ToServer): any => { + if (bare.body.tag === "ActionRequest") { + return { + body: { + tag: "ActionRequest", + val: { + id: bare.body.val.id, + name: bare.body.val.name, + args: cbor.decode( + new Uint8Array(bare.body.val.args), + ), + }, + }, + }; + } else { + // SubscriptionRequest has no ArrayBuffer fields + return bare; + } + }, ); } @@ -124,7 +158,17 @@ export async function processMessage< I, DB extends AnyDatabaseProvider, >( - message: protocol.ToServer, + message: { + body: + | { + tag: "ActionRequest"; + val: { id: bigint; name: string; args: unknown }; + } + | { + tag: "SubscriptionRequest"; + val: { eventName: string; subscribe: boolean }; + }; + }, actor: ActorInstance, conn: Conn, handler: ProcessMessageHandler, @@ -140,10 +184,9 @@ export async function processMessage< throw new errors.Unsupported("Action"); } - const { id, name, args: argsRaw } = message.body.val; + const { id, name, args } = message.body.val; actionId = id; actionName = name; - const args = cbor.decode(new Uint8Array(argsRaw)); actor.rLog.debug({ msg: "processing action request", @@ -155,7 +198,11 @@ export async function processMessage< // Process the action request and wait for the result // This will wait for async actions to complete - const output = await handler.onExecuteAction(ctx, name, args); + const output = await handler.onExecuteAction( + ctx, + name, + args as unknown[], + ); actor.rLog.debug({ msg: "sending action response", @@ -167,20 +214,30 @@ export async function processMessage< // Send the response back to the client conn[CONN_SEND_MESSAGE_SYMBOL]( - new CachedSerializer( - { + new CachedSerializer( + output, + TO_CLIENT_VERSIONED, + ToClientSchema, + // JSON: output is the raw value + (value): ToClientJson => ({ body: { - tag: "ActionResponse", + tag: "ActionResponse" as const, val: { id: id, - output: bufferToArrayBuffer( - cbor.encode(output), - ), + output: value, }, }, - }, - TO_CLIENT_VERSIONED, - ToClientSchema, + }), + // BARE/CBOR: output needs to be CBOR-encoded to ArrayBuffer + (value): protocol.ToClient => ({ + body: { + tag: "ActionResponse" as const, + val: { + id: id, + output: bufferToArrayBuffer(cbor.encode(value)), + }, + }, + }), ), ); @@ -236,24 +293,42 @@ export async function processMessage< }); // Build response + const errorData = { group, code, message, metadata, actionId }; conn[CONN_SEND_MESSAGE_SYMBOL]( - new CachedSerializer( - { + new CachedSerializer( + errorData, + TO_CLIENT_VERSIONED, + ToClientSchema, + // JSON: metadata is the raw value + (value): ToClientJson => ({ body: { - tag: "Error", + tag: "Error" as const, val: { - group, - code, - message, - metadata: bufferToArrayBuffer( - cbor.encode(metadata), - ), - actionId: actionId ?? null, + group: value.group, + code: value.code, + message: value.message, + metadata: value.metadata ?? null, + actionId: value.actionId ?? null, }, }, - }, - TO_CLIENT_VERSIONED, - ToClientSchema, + }), + // BARE/CBOR: metadata needs to be CBOR-encoded to ArrayBuffer + (value): protocol.ToClient => ({ + body: { + tag: "Error" as const, + val: { + group: value.group, + code: value.code, + message: value.message, + metadata: value.metadata + ? bufferToArrayBuffer( + cbor.encode(value.metadata), + ) + : null, + actionId: value.actionId ?? null, + }, + }, + }), ), ); diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/protocol/serde.ts b/rivetkit-typescript/packages/rivetkit/src/actor/protocol/serde.ts index 157e6d3aa5..4e4ed85432 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/protocol/serde.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/protocol/serde.ts @@ -22,23 +22,29 @@ export type Encoding = z.infer; /** * Helper class that helps serialize data without re-serializing for the same encoding. */ -export class CachedSerializer { - #data: T | TJson; +export class CachedSerializer { + #data: T; #cache = new Map(); - #versionedDataHandler: VersionedDataHandler; + #versionedDataHandler: VersionedDataHandler; #zodSchema: z.ZodType; + #toJson: (value: T) => TJson; + #toBare: (value: T) => TBare; constructor( - data: T | TJson, - versionedDataHandler: VersionedDataHandler, + data: T, + versionedDataHandler: VersionedDataHandler, zodSchema: z.ZodType, + toJson: (value: T) => TJson, + toBare: (value: T) => TBare, ) { this.#data = data; this.#versionedDataHandler = versionedDataHandler; this.#zodSchema = zodSchema; + this.#toJson = toJson; + this.#toBare = toBare; } - public get rawData(): T | TJson { + public get rawData(): T { return this.#data; } @@ -52,6 +58,8 @@ export class CachedSerializer { this.#data, this.#versionedDataHandler, this.#zodSchema, + this.#toJson, + this.#toBare, ); this.#cache.set(encoding, serialized); return serialized; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts b/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts index 9f2d4a08fe..3b47feaa52 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts @@ -1,10 +1,6 @@ import * as cbor from "cbor-x"; import type { Context as HonoContext, HonoRequest } from "hono"; import type { WSContext } from "hono/ws"; -import { - HttpActionRequestSchema, - HttpActionResponseSchema, -} from "@/actor/client-protocol-schema-json/mod"; import type { AnyConn } from "@/actor/conn/mod"; import { ActionContext } from "@/actor/contexts/action"; import * as errors from "@/actor/errors"; @@ -31,6 +27,12 @@ import { HTTP_ACTION_REQUEST_VERSIONED, HTTP_ACTION_RESPONSE_VERSIONED, } from "@/schemas/client-protocol/versioned"; +import { + type HttpActionRequest as HttpActionRequestJson, + HttpActionRequestSchema, + type HttpActionResponse as HttpActionResponseJson, + HttpActionResponseSchema, +} from "@/schemas/client-protocol-zod/mod"; import { contentTypeForEncoding, deserializeWithEncoding, @@ -339,8 +341,13 @@ export async function handleAction( new Uint8Array(arrayBuffer), HTTP_ACTION_REQUEST_VERSIONED, HttpActionRequestSchema, + // JSON: args is already the decoded value (raw object/array) + (json: HttpActionRequestJson) => json.args, + // BARE/CBOR: args is ArrayBuffer that needs CBOR-decoding + (bare: protocol.HttpActionRequest) => + cbor.decode(new Uint8Array(bare.args)), ); - const actionArgs = cbor.decode(new Uint8Array(request.args)); + const actionArgs = request; // Invoke the action let actor: AnyActorInstance | undefined; @@ -369,14 +376,17 @@ export async function handleAction( } // Send response - const responseData: protocol.HttpActionResponse = { - output: bufferToArrayBuffer(cbor.encode(output)), - }; const serialized = serializeWithEncoding( encoding, - responseData, + output, HTTP_ACTION_RESPONSE_VERSIONED, HttpActionResponseSchema, + // JSON: output is the raw value (will be serialized by jsonStringifyCompat) + (value): HttpActionResponseJson => ({ output: value }), + // BARE/CBOR: output needs to be CBOR-encoded to ArrayBuffer + (value): protocol.HttpActionResponse => ({ + output: bufferToArrayBuffer(cbor.encode(value)), + }), ); // TODO: Remvoe any, Hono is being a dumbass diff --git a/rivetkit-typescript/packages/rivetkit/src/client/actor-conn.ts b/rivetkit-typescript/packages/rivetkit/src/client/actor-conn.ts index 006a7a4b35..93041edd9b 100644 --- a/rivetkit-typescript/packages/rivetkit/src/client/actor-conn.ts +++ b/rivetkit-typescript/packages/rivetkit/src/client/actor-conn.ts @@ -2,10 +2,6 @@ import * as cbor from "cbor-x"; import invariant from "invariant"; import pRetry from "p-retry"; import type { CloseEvent } from "ws"; -import { - ToClientSchema, - ToServerSchema, -} from "@/actor/client-protocol-schema-json/mod"; import type { AnyActorDefinition } from "@/actor/definition"; import { inputDataToBuffer } from "@/actor/protocol/old"; import { type Encoding, jsonStringifyCompat } from "@/actor/protocol/serde"; @@ -29,6 +25,12 @@ import { TO_CLIENT_VERSIONED, TO_SERVER_VERSIONED, } from "@/schemas/client-protocol/versioned"; +import { + type ToClient as ToClientJson, + ToClientSchema, + type ToServer as ToServerJson, + ToServerSchema, +} from "@/schemas/client-protocol-zod/mod"; import { deserializeWithEncoding, encodingIsBinary, @@ -53,7 +55,7 @@ import { interface ActionInFlight { name: string; - resolve: (response: protocol.ActionResponse) => void; + resolve: (response: { id: bigint; output: unknown }) => void; reject: (error: Error) => void; } @@ -99,7 +101,17 @@ export class ActorConnRaw { #actorId?: string; #connectionId?: string; - #messageQueue: protocol.ToServer[] = []; + #messageQueue: Array<{ + body: + | { + tag: "ActionRequest"; + val: { id: bigint; name: string; args: unknown }; + } + | { + tag: "SubscriptionRequest"; + val: { eventName: string; subscribe: boolean }; + }; + }> = []; #actionsInFlight = new Map(); // biome-ignore lint/suspicious/noExplicitAny: Unknown subscription type @@ -176,8 +188,10 @@ export class ActorConnRaw { const actionId = this.#actionIdCounter; this.#actionIdCounter += 1; - const { promise, resolve, reject } = - promiseWithResolvers(); + const { promise, resolve, reject } = promiseWithResolvers<{ + id: bigint; + output: unknown; + }>(); this.#actionsInFlight.set(actionId, { name: opts.name, resolve, @@ -190,10 +204,10 @@ export class ActorConnRaw { val: { id: BigInt(actionId), name: opts.name, - args: bufferToArrayBuffer(cbor.encode(opts.args)), + args: opts.args, }, }, - } satisfies protocol.ToServer); + }); // TODO: Throw error if disconnect is called @@ -203,7 +217,7 @@ export class ActorConnRaw { `Request ID ${actionId} does not match response ID ${responseId}`, ); - return cbor.decode(new Uint8Array(output)) as Response; + return output as Response; } /** @@ -553,16 +567,15 @@ enc return inFlight; } - #dispatchEvent(event: protocol.Event) { - const { name, args: argsRaw } = event; - const args = cbor.decode(new Uint8Array(argsRaw)); + #dispatchEvent(event: { name: string; args: unknown }) { + const { name, args } = event; const listeners = this.#eventSubscriptions.get(name); if (!listeners) return; // Create a new array to avoid issues with listeners being removed during iteration for (const listener of [...listeners]) { - listener.callback(...args); + listener.callback(...(args as unknown[])); // Remove if this was a one-time listener if (listener.once) { @@ -668,7 +681,20 @@ enc }; } - #sendMessage(message: protocol.ToServer, opts?: SendHttpMessageOpts) { + #sendMessage( + message: { + body: + | { + tag: "ActionRequest"; + val: { id: bigint; name: string; args: unknown }; + } + | { + tag: "SubscriptionRequest"; + val: { eventName: string; subscribe: boolean }; + }; + }, + opts?: SendHttpMessageOpts, + ) { if (this.#disposed) { throw new errors.ActorConnDisposed(); } @@ -698,6 +724,27 @@ enc message, TO_SERVER_VERSIONED, ToServerSchema, + // JSON: args is the raw value + (msg): ToServerJson => msg as ToServerJson, + // BARE: args needs to be CBOR-encoded to ArrayBuffer + (msg): protocol.ToServer => { + if (msg.body.tag === "ActionRequest") { + return { + body: { + tag: "ActionRequest", + val: { + id: msg.body.val.id, + name: msg.body.val.name, + args: bufferToArrayBuffer( + cbor.encode(msg.body.val.args), + ), + }, + }, + }; + } else { + return msg as protocol.ToServer; + } + }, ); this.#websocket.send(messageSerialized); logger().trace({ @@ -739,7 +786,22 @@ enc } } - async #parseMessage(data: ConnMessage): Promise { + async #parseMessage(data: ConnMessage): Promise<{ + body: + | { tag: "Init"; val: { actorId: string; connectionId: string } } + | { + tag: "Error"; + val: { + group: string; + code: string; + message: string; + metadata: unknown; + actionId: bigint | null; + }; + } + | { tag: "ActionResponse"; val: { id: bigint; output: unknown } } + | { tag: "Event"; val: { name: string; args: unknown } }; + }> { invariant(this.#websocket, "websocket must be defined"); const buffer = await inputDataToBuffer(data); @@ -749,6 +811,58 @@ enc buffer, TO_CLIENT_VERSIONED, ToClientSchema, + // JSON: values are already the correct type + (msg): ToClientJson => msg as ToClientJson, + // BARE: need to decode ArrayBuffer fields back to unknown + (msg): any => { + if (msg.body.tag === "Error") { + return { + body: { + tag: "Error", + val: { + group: msg.body.val.group, + code: msg.body.val.code, + message: msg.body.val.message, + metadata: msg.body.val.metadata + ? cbor.decode( + new Uint8Array( + msg.body.val.metadata, + ), + ) + : null, + actionId: msg.body.val.actionId, + }, + }, + }; + } else if (msg.body.tag === "ActionResponse") { + return { + body: { + tag: "ActionResponse", + val: { + id: msg.body.val.id, + output: cbor.decode( + new Uint8Array(msg.body.val.output), + ), + }, + }, + }; + } else if (msg.body.tag === "Event") { + return { + body: { + tag: "Event", + val: { + name: msg.body.val.name, + args: cbor.decode( + new Uint8Array(msg.body.val.args), + ), + }, + }, + }; + } else { + // Init has no ArrayBuffer fields + return msg; + } + }, ); } diff --git a/rivetkit-typescript/packages/rivetkit/src/client/actor-handle.ts b/rivetkit-typescript/packages/rivetkit/src/client/actor-handle.ts index de1f2c7e05..a0abcc8199 100644 --- a/rivetkit-typescript/packages/rivetkit/src/client/actor-handle.ts +++ b/rivetkit-typescript/packages/rivetkit/src/client/actor-handle.ts @@ -1,9 +1,5 @@ import * as cbor from "cbor-x"; import invariant from "invariant"; -import { - HttpActionRequestSchema, - HttpActionResponseSchema, -} from "@/actor/client-protocol-schema-json/mod"; import type { AnyActorDefinition } from "@/actor/definition"; import type { Encoding } from "@/actor/protocol/serde"; import { assertUnreachable } from "@/actor/utils"; @@ -19,6 +15,12 @@ import { HTTP_ACTION_REQUEST_VERSIONED, HTTP_ACTION_RESPONSE_VERSIONED, } from "@/schemas/client-protocol/versioned"; +import { + type HttpActionRequest as HttpActionRequestJson, + HttpActionRequestSchema, + type HttpActionResponse as HttpActionResponseJson, + HttpActionResponseSchema, +} from "@/schemas/client-protocol-zod/mod"; import { bufferToArrayBuffer } from "@/utils"; import type { ActorDefinitionActions } from "./actor-common"; import { type ActorConn, ActorConnRaw } from "./actor-conn"; @@ -104,8 +106,12 @@ export class ActorHandleRaw { encoding: this.#encoding, }); const responseData = await sendHttpRequest< - protocol.HttpActionRequest, - protocol.HttpActionResponse + protocol.HttpActionRequest, // Bare type + protocol.HttpActionResponse, // Bare type + HttpActionRequestJson, // Json type + HttpActionResponseJson, // Json type + unknown[], // Request type (the args array) + Response // Response type (the output value) >({ url: `http://actor/action/${encodeURIComponent(opts.name)}`, method: "POST", @@ -115,9 +121,7 @@ export class ActorHandleRaw { ? { [HEADER_CONN_PARAMS]: JSON.stringify(this.#params) } : {}), }, - body: { - args: bufferToArrayBuffer(cbor.encode(opts.args)), - } satisfies protocol.HttpActionRequest, + body: opts.args, encoding: this.#encoding, customFetch: this.#driver.sendRequest.bind( this.#driver, @@ -128,9 +132,22 @@ export class ActorHandleRaw { responseVersionedDataHandler: HTTP_ACTION_RESPONSE_VERSIONED, requestZodSchema: HttpActionRequestSchema, responseZodSchema: HttpActionResponseSchema, + // JSON Request: args is the raw value + requestToJson: (args): HttpActionRequestJson => ({ + args, + }), + // BARE Request: args needs to be CBOR-encoded + requestToBare: (args): protocol.HttpActionRequest => ({ + args: bufferToArrayBuffer(cbor.encode(args)), + }), + // JSON Response: output is the raw value + responseFromJson: (json): Response => json.output as Response, + // BARE Response: output is ArrayBuffer that needs CBOR-decoding + responseFromBare: (bare): Response => + cbor.decode(new Uint8Array(bare.output)) as Response, }); - return cbor.decode(new Uint8Array(responseData.output)); + return responseData; } catch (err) { // Standardize to ClientActorError instead of the native backend error const { group, code, message, metadata } = deconstructError( diff --git a/rivetkit-typescript/packages/rivetkit/src/client/utils.ts b/rivetkit-typescript/packages/rivetkit/src/client/utils.ts index 1f8d4349e6..2aa66ac858 100644 --- a/rivetkit-typescript/packages/rivetkit/src/client/utils.ts +++ b/rivetkit-typescript/packages/rivetkit/src/client/utils.ts @@ -1,12 +1,15 @@ import * as cbor from "cbor-x"; import invariant from "invariant"; import type { z } from "zod"; -import { HttpResponseErrorSchema } from "@/actor/client-protocol-schema-json/mod"; import type { Encoding } from "@/actor/protocol/serde"; import { assertUnreachable } from "@/common/utils"; import type { VersionedDataHandler } from "@/common/versioned-data"; import type { HttpResponseError } from "@/schemas/client-protocol/mod"; import { HTTP_RESPONSE_ERROR_VERSIONED } from "@/schemas/client-protocol/versioned"; +import { + type HttpResponseError as HttpResponseErrorJson, + HttpResponseErrorSchema, +} from "@/schemas/client-protocol-zod/mod"; import { contentTypeForEncoding, deserializeWithEncoding, @@ -36,35 +39,50 @@ export function messageLength(message: WebSocketMessage): number { } export interface HttpRequestOpts< - RequestBody, - ResponseBody, - RequestJson = RequestBody, - ResponseJson = ResponseBody, + RequestBare, + ResponseBare, + RequestJson = RequestBare, + ResponseJson = ResponseBare, + Request = RequestBare, + Response = ResponseBare, > { method: string; url: string; headers: Record; - body?: RequestBody | RequestJson; + body?: Request; encoding: Encoding; skipParseResponse?: boolean; signal?: AbortSignal; - customFetch?: (req: Request) => Promise; - requestVersionedDataHandler: VersionedDataHandler | undefined; + customFetch?: (req: globalThis.Request) => Promise; + requestVersionedDataHandler: VersionedDataHandler | undefined; responseVersionedDataHandler: - | VersionedDataHandler + | VersionedDataHandler | undefined; requestZodSchema: z.ZodType; responseZodSchema: z.ZodType; + requestToJson: (value: Request) => RequestJson; + requestToBare: (value: Request) => RequestBare; + responseFromJson: (value: ResponseJson) => Response; + responseFromBare: (value: ResponseBare) => Response; } export async function sendHttpRequest< - RequestBody = unknown, - ResponseBody = unknown, - RequestJson = RequestBody, - ResponseJson = ResponseBody, + RequestBare = unknown, + ResponseBare = unknown, + RequestJson = RequestBare, + ResponseJson = ResponseBare, + Request = RequestBare, + Response = ResponseBare, >( - opts: HttpRequestOpts, -): Promise { + opts: HttpRequestOpts< + RequestBare, + ResponseBare, + RequestJson, + ResponseJson, + Request, + Response + >, +): Promise { logger().debug({ msg: "sending http request", url: opts.url, @@ -77,20 +95,22 @@ export async function sendHttpRequest< if (opts.method === "POST" || opts.method === "PUT") { invariant(opts.body !== undefined, "missing body"); contentType = contentTypeForEncoding(opts.encoding); - bodyData = serializeWithEncoding( + bodyData = serializeWithEncoding( opts.encoding, opts.body, opts.requestVersionedDataHandler, opts.requestZodSchema, + opts.requestToJson, + opts.requestToBare, ); } // Send request - let response: Response; + let response: globalThis.Response; try { // Make the HTTP request response = await (opts.customFetch ?? fetch)( - new Request(opts.url, { + new globalThis.Request(opts.url, { method: opts.method, headers: { ...opts.headers, @@ -116,13 +136,29 @@ export async function sendHttpRequest< if (!response.ok) { // Attempt to parse structured data const bufferResponse = await response.arrayBuffer(); - let responseData: HttpResponseError; + let responseData: { + group: string; + code: string; + message: string; + metadata: unknown; + }; try { responseData = deserializeWithEncoding( opts.encoding, new Uint8Array(bufferResponse), HTTP_RESPONSE_ERROR_VERSIONED, HttpResponseErrorSchema, + // JSON: metadata is already unknown + (json): HttpResponseErrorJson => json as HttpResponseErrorJson, + // BARE: decode ArrayBuffer metadata to unknown + (bare): any => ({ + group: bare.group, + code: bare.code, + message: bare.message, + metadata: bare.metadata + ? cbor.decode(new Uint8Array(bare.metadata)) + : null, + }), ); } catch (error) { //logger().warn("failed to cleanly parse error, this is likely because a non-structured response is being served", { @@ -147,36 +183,30 @@ export async function sendHttpRequest< } } - // Decode metadata based on encoding - only binary encodings have CBOR-encoded metadata - let decodedMetadata: unknown; - if (responseData.metadata && encodingIsBinary(opts.encoding)) { - decodedMetadata = cbor.decode( - new Uint8Array(responseData.metadata), - ); - } - // Throw structured error throw new ActorError( responseData.group, responseData.code, responseData.message, - decodedMetadata, + responseData.metadata, ); } // Some requests don't need the success response to be parsed, so this can speed things up if (opts.skipParseResponse) { - return undefined as ResponseBody | ResponseJson; + return undefined as Response; } // Parse the response based on encoding try { const buffer = new Uint8Array(await response.arrayBuffer()); - return deserializeWithEncoding( + return deserializeWithEncoding( opts.encoding, buffer, opts.responseVersionedDataHandler, opts.responseZodSchema, + opts.responseFromJson, + opts.responseFromBare, ); } catch (error) { throw new HttpRequestError(`Failed to parse response: ${error}`, { diff --git a/rivetkit-typescript/packages/rivetkit/src/common/router.ts b/rivetkit-typescript/packages/rivetkit/src/common/router.ts index ded59cc47a..99d3477af3 100644 --- a/rivetkit-typescript/packages/rivetkit/src/common/router.ts +++ b/rivetkit-typescript/packages/rivetkit/src/common/router.ts @@ -1,6 +1,5 @@ import * as cbor from "cbor-x"; import type { Context as HonoContext, Next } from "hono"; -import { HttpResponseErrorSchema } from "@/actor/client-protocol-schema-json/mod"; import type { Encoding } from "@/actor/protocol/serde"; import { getRequestEncoding, @@ -9,8 +8,12 @@ import { import { buildActorNames, type RegistryConfig } from "@/registry/config"; import type { RunnerConfig } from "@/registry/run-config"; import { getEndpoint } from "@/remote-manager-driver/api-utils"; -import { HttpResponseError } from "@/schemas/client-protocol/mod"; +import type * as protocol from "@/schemas/client-protocol/mod"; import { HTTP_RESPONSE_ERROR_VERSIONED } from "@/schemas/client-protocol/versioned"; +import { + type HttpResponseError as HttpResponseErrorJson, + HttpResponseErrorSchema, +} from "@/schemas/client-protocol-zod/mod"; import { encodingIsBinary, serializeWithEncoding } from "@/serde"; import { bufferToArrayBuffer, getEnvUniversal, VERSION } from "@/utils"; import { getLogger, type Logger } from "./log"; @@ -69,19 +72,28 @@ export function handleRouteError(error: unknown, c: HonoContext) { encoding = "json"; } + const errorData = { group, code, message, metadata }; const output = serializeWithEncoding( encoding, - { - group, - code, - message, - // TODO: Cannot serialize non-binary meta since it requires ArrayBuffer atm - metadata: encodingIsBinary(encoding) - ? bufferToArrayBuffer(cbor.encode(metadata)) - : null, - }, + errorData, HTTP_RESPONSE_ERROR_VERSIONED, HttpResponseErrorSchema, + // JSON: metadata is the raw value (will be serialized by jsonStringifyCompat) + (value): HttpResponseErrorJson => ({ + group: value.group, + code: value.code, + message: value.message, + metadata: value.metadata ?? null, + }), + // BARE/CBOR: metadata needs to be CBOR-encoded to ArrayBuffer + (value): protocol.HttpResponseError => ({ + group: value.group, + code: value.code, + message: value.message, + metadata: value.metadata + ? bufferToArrayBuffer(cbor.encode(value.metadata)) + : null, + }), ); // TODO: Remove any diff --git a/rivetkit-typescript/packages/rivetkit/src/remote-manager-driver/api-utils.ts b/rivetkit-typescript/packages/rivetkit/src/remote-manager-driver/api-utils.ts index aa4c0e49c0..7499e2f054 100644 --- a/rivetkit-typescript/packages/rivetkit/src/remote-manager-driver/api-utils.ts +++ b/rivetkit-typescript/packages/rivetkit/src/remote-manager-driver/api-utils.ts @@ -54,5 +54,10 @@ export async function apiCall( responseVersionedDataHandler: undefined, requestZodSchema: z.any() as z.ZodType, responseZodSchema: z.any() as z.ZodType, + // Identity conversions (passthrough for generic API calls) + requestToJson: (value) => value, + requestToBare: (value) => value, + responseFromJson: (value) => value, + responseFromBare: (value) => value, }); } diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/client-protocol-schema-json/mod.ts b/rivetkit-typescript/packages/rivetkit/src/schemas/client-protocol-zod/mod.ts similarity index 88% rename from rivetkit-typescript/packages/rivetkit/src/actor/client-protocol-schema-json/mod.ts rename to rivetkit-typescript/packages/rivetkit/src/schemas/client-protocol-zod/mod.ts index 72b53ed0fe..ca70a30023 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/client-protocol-schema-json/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/schemas/client-protocol-zod/mod.ts @@ -1,8 +1,6 @@ import { z } from "zod"; -// Helper schemas for ArrayBuffer handling in JSON -const ArrayBufferSchema = z.instanceof(ArrayBuffer); -const OptionalArrayBufferSchema = ArrayBufferSchema.nullable(); +// Helper schemas const UintSchema = z.bigint(); const OptionalUintSchema = UintSchema.nullable(); @@ -17,20 +15,20 @@ export const ErrorSchema = z.object({ group: z.string(), code: z.string(), message: z.string(), - metadata: OptionalArrayBufferSchema, + metadata: z.unknown().nullable(), actionId: OptionalUintSchema, }); export type Error = z.infer; export const ActionResponseSchema = z.object({ id: UintSchema, - output: ArrayBufferSchema, + output: z.unknown(), }); export type ActionResponse = z.infer; export const EventSchema = z.object({ name: z.string(), - args: ArrayBufferSchema, + args: z.unknown(), }); export type Event = z.infer; @@ -51,7 +49,7 @@ export type ToClient = z.infer; export const ActionRequestSchema = z.object({ id: UintSchema, name: z.string(), - args: ArrayBufferSchema, + args: z.unknown(), }); export type ActionRequest = z.infer; @@ -77,12 +75,12 @@ export type ToServer = z.infer; // MARK: HTTP Action export const HttpActionRequestSchema = z.object({ - args: ArrayBufferSchema, + args: z.unknown(), }); export type HttpActionRequest = z.infer; export const HttpActionResponseSchema = z.object({ - output: ArrayBufferSchema, + output: z.unknown(), }); export type HttpActionResponse = z.infer; @@ -91,7 +89,7 @@ export const HttpResponseErrorSchema = z.object({ group: z.string(), code: z.string(), message: z.string(), - metadata: OptionalArrayBufferSchema, + metadata: z.unknown().nullable(), }); export type HttpResponseError = z.infer; diff --git a/rivetkit-typescript/packages/rivetkit/src/serde.ts b/rivetkit-typescript/packages/rivetkit/src/serde.ts index c62ffcf70c..d006f2e115 100644 --- a/rivetkit-typescript/packages/rivetkit/src/serde.ts +++ b/rivetkit-typescript/packages/rivetkit/src/serde.ts @@ -53,35 +53,43 @@ export function wsBinaryTypeForEncoding( } } -export function serializeWithEncoding( +export function serializeWithEncoding( encoding: Encoding, - value: T | TJson, - versionedDataHandler: VersionedDataHandler | undefined, + value: T, + versionedDataHandler: VersionedDataHandler | undefined, zodSchema: z.ZodType, + toJson: (value: T) => TJson, + toBare: (value: T) => TBare, ): Uint8Array | string { if (encoding === "json") { - const validated = zodSchema.parse(value); + const jsonValue = toJson(value); + const validated = zodSchema.parse(jsonValue); return jsonStringifyCompat(validated); } else if (encoding === "cbor") { - return cbor.encode(value); + const jsonValue = toJson(value); + const validated = zodSchema.parse(jsonValue); + return cbor.encode(validated); } else if (encoding === "bare") { if (!versionedDataHandler) { throw new Error( "VersionedDataHandler is required for 'bare' encoding", ); } - return versionedDataHandler.serializeWithEmbeddedVersion(value as T); + const bareValue = toBare(value); + return versionedDataHandler.serializeWithEmbeddedVersion(bareValue); } else { assertUnreachable(encoding); } } -export function deserializeWithEncoding( +export function deserializeWithEncoding( encoding: Encoding, buffer: Uint8Array | string, - versionedDataHandler: VersionedDataHandler | undefined, + versionedDataHandler: VersionedDataHandler | undefined, zodSchema: z.ZodType, -): T | TJson { + fromJson: (value: TJson) => T, + fromBare: (value: TBare) => T, +): T { if (encoding === "json") { let parsed: unknown; if (typeof buffer === "string") { @@ -91,13 +99,19 @@ export function deserializeWithEncoding( const jsonString = decoder.decode(buffer); parsed = jsonParseCompat(jsonString); } - return zodSchema.parse(parsed); + const validated = zodSchema.parse(parsed); + return fromJson(validated); } else if (encoding === "cbor") { invariant( typeof buffer !== "string", "buffer cannot be string for cbor encoding", ); - return cbor.decode(buffer); + // Decode CBOR to get JavaScript values (similar to JSON.parse) + const decoded: unknown = cbor.decode(buffer); + // Validate with Zod schema (CBOR produces same structure as JSON) + const validated = zodSchema.parse(decoded); + // CBOR decoding produces JS objects, use fromJson + return fromJson(validated); } else if (encoding === "bare") { invariant( typeof buffer !== "string", @@ -108,7 +122,9 @@ export function deserializeWithEncoding( "VersionedDataHandler is required for 'bare' encoding", ); } - return versionedDataHandler.deserializeWithEmbeddedVersion(buffer); + const bareValue = + versionedDataHandler.deserializeWithEmbeddedVersion(buffer); + return fromBare(bareValue); } else { assertUnreachable(encoding); }