Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import { z } from "zod";

// Helper schemas for ArrayBuffer handling in JSON
const ArrayBufferSchema = z.instanceof(ArrayBuffer);
const OptionalArrayBufferSchema = ArrayBufferSchema.nullable();
const UintSchema = z.bigint();
const OptionalUintSchema = UintSchema.nullable();

// MARK: Message To Client
export const InitSchema = z.object({
actorId: z.string(),
connectionId: z.string(),
});
export type Init = z.infer<typeof InitSchema>;

export const ErrorSchema = z.object({
group: z.string(),
code: z.string(),
message: z.string(),
metadata: OptionalArrayBufferSchema,
actionId: OptionalUintSchema,
});
export type Error = z.infer<typeof ErrorSchema>;

export const ActionResponseSchema = z.object({
id: UintSchema,
output: ArrayBufferSchema,
});
export type ActionResponse = z.infer<typeof ActionResponseSchema>;

export const EventSchema = z.object({
name: z.string(),
args: ArrayBufferSchema,
});
export type Event = z.infer<typeof EventSchema>;

export const ToClientBodySchema = z.discriminatedUnion("tag", [
z.object({ tag: z.literal("Init"), val: InitSchema }),
z.object({ tag: z.literal("Error"), val: ErrorSchema }),
z.object({ tag: z.literal("ActionResponse"), val: ActionResponseSchema }),
z.object({ tag: z.literal("Event"), val: EventSchema }),
]);
export type ToClientBody = z.infer<typeof ToClientBodySchema>;

export const ToClientSchema = z.object({
body: ToClientBodySchema,
});
export type ToClient = z.infer<typeof ToClientSchema>;

// MARK: Message To Server
export const ActionRequestSchema = z.object({
id: UintSchema,
name: z.string(),
args: ArrayBufferSchema,
});
export type ActionRequest = z.infer<typeof ActionRequestSchema>;

export const SubscriptionRequestSchema = z.object({
eventName: z.string(),
subscribe: z.boolean(),
});
export type SubscriptionRequest = z.infer<typeof SubscriptionRequestSchema>;

export const ToServerBodySchema = z.discriminatedUnion("tag", [
z.object({ tag: z.literal("ActionRequest"), val: ActionRequestSchema }),
z.object({
tag: z.literal("SubscriptionRequest"),
val: SubscriptionRequestSchema,
}),
]);
export type ToServerBody = z.infer<typeof ToServerBodySchema>;

export const ToServerSchema = z.object({
body: ToServerBodySchema,
});
export type ToServer = z.infer<typeof ToServerSchema>;

// MARK: HTTP Action
export const HttpActionRequestSchema = z.object({
args: ArrayBufferSchema,
});
export type HttpActionRequest = z.infer<typeof HttpActionRequestSchema>;

export const HttpActionResponseSchema = z.object({
output: ArrayBufferSchema,
});
export type HttpActionResponse = z.infer<typeof HttpActionResponseSchema>;

// MARK: HTTP Error
export const HttpResponseErrorSchema = z.object({
group: z.string(),
code: z.string(),
message: z.string(),
metadata: OptionalArrayBufferSchema,
});
export type HttpResponseError = z.infer<typeof HttpResponseErrorSchema>;

// MARK: HTTP Resolve
export const HttpResolveRequestSchema = z.null();
export type HttpResolveRequest = z.infer<typeof HttpResolveRequestSchema>;

export const HttpResolveResponseSchema = z.object({
actorId: z.string(),
});
export type HttpResolveResponse = z.infer<typeof HttpResolveResponseSchema>;
2 changes: 2 additions & 0 deletions rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import * as cbor from "cbor-x";
import { ToClientSchema } from "@/actor/client-protocol-schema-json/mod";
import type * as protocol from "@/schemas/client-protocol/mod";
import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned";
import { arrayBuffersEqual, bufferToArrayBuffer } from "@/utils";
Expand Down Expand Up @@ -205,6 +206,7 @@ export class Conn<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
},
},
TO_CLIENT_VERSIONED,
ToClientSchema,
),
);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import * as cbor from "cbor-x";
import { ToClientSchema } from "@/actor/client-protocol-schema-json/mod";
import type * as protocol from "@/schemas/client-protocol/mod";
import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned";
import { bufferToArrayBuffer } from "@/utils";
Expand Down Expand Up @@ -190,6 +191,7 @@ export class EventManager<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
},
},
TO_CLIENT_VERSIONED,
ToClientSchema,
);

// Send to all subscribers
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import * as cbor from "cbor-x";
import invariant from "invariant";
import { ToClientSchema } from "@/actor/client-protocol-schema-json/mod";
import type { ActorKey } from "@/actor/mod";
import type { Client } from "@/client/client";
import { getBaseLogger, getIncludeTarget, type Logger } from "@/common/log";
Expand Down Expand Up @@ -522,6 +523,7 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
},
},
TO_CLIENT_VERSIONED,
ToClientSchema,
),
);

Expand Down
13 changes: 12 additions & 1 deletion rivetkit-typescript/packages/rivetkit/src/actor/protocol/old.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import * as cbor from "cbor-x";
import { z } from "zod";
import {
ToClientSchema,
ToServerSchema,
} from "@/actor/client-protocol-schema-json/mod";
import type { AnyDatabaseProvider } from "@/actor/database";
import * as errors from "@/actor/errors";
import {
Expand Down Expand Up @@ -81,7 +85,12 @@ export async function parseMessage(
}

// Deserialize message
return deserializeWithEncoding(opts.encoding, buffer, TO_SERVER_VERSIONED);
return deserializeWithEncoding(
opts.encoding,
buffer,
TO_SERVER_VERSIONED,
ToServerSchema,
);
}

export interface ProcessMessageHandler<
Expand Down Expand Up @@ -171,6 +180,7 @@ export async function processMessage<
},
},
TO_CLIENT_VERSIONED,
ToClientSchema,
),
);

Expand Down Expand Up @@ -243,6 +253,7 @@ export async function processMessage<
},
},
TO_CLIENT_VERSIONED,
ToClientSchema,
),
);

Expand Down
15 changes: 11 additions & 4 deletions rivetkit-typescript/packages/rivetkit/src/actor/protocol/serde.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,23 @@ export type Encoding = z.infer<typeof EncodingSchema>;
/**
* Helper class that helps serialize data without re-serializing for the same encoding.
*/
export class CachedSerializer<T> {
#data: T;
export class CachedSerializer<T, TJson = T> {
#data: T | TJson;
#cache = new Map<Encoding, OutputData>();
#versionedDataHandler: VersionedDataHandler<T>;
#zodSchema: z.ZodType<TJson>;

constructor(data: T, versionedDataHandler: VersionedDataHandler<T>) {
constructor(
data: T | TJson,
versionedDataHandler: VersionedDataHandler<T>,
zodSchema: z.ZodType<TJson>,
) {
this.#data = data;
this.#versionedDataHandler = versionedDataHandler;
this.#zodSchema = zodSchema;
}

public get rawData(): T {
public get rawData(): T | TJson {
return this.#data;
}

Expand All @@ -45,6 +51,7 @@ export class CachedSerializer<T> {
encoding,
this.#data,
this.#versionedDataHandler,
this.#zodSchema,
);
this.#cache.set(encoding, serialized);
return serialized;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import * as cbor from "cbor-x";
import type { Context as HonoContext, HonoRequest } from "hono";
import type { WSContext } from "hono/ws";
import {
HttpActionRequestSchema,
HttpActionResponseSchema,
} from "@/actor/client-protocol-schema-json/mod";
import type { AnyConn } from "@/actor/conn/mod";
import { ActionContext } from "@/actor/contexts/action";
import * as errors from "@/actor/errors";
Expand Down Expand Up @@ -334,6 +338,7 @@ export async function handleAction(
encoding,
new Uint8Array(arrayBuffer),
HTTP_ACTION_REQUEST_VERSIONED,
HttpActionRequestSchema,
);
const actionArgs = cbor.decode(new Uint8Array(request.args));

Expand Down Expand Up @@ -371,6 +376,7 @@ export async function handleAction(
encoding,
responseData,
HTTP_ACTION_RESPONSE_VERSIONED,
HttpActionResponseSchema,
);

// TODO: Remvoe any, Hono is being a dumbass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ import * as cbor from "cbor-x";
import invariant from "invariant";
import pRetry from "p-retry";
import type { CloseEvent } from "ws";
import {
ToClientSchema,
ToServerSchema,
} from "@/actor/client-protocol-schema-json/mod";
import type { AnyActorDefinition } from "@/actor/definition";
import { inputDataToBuffer } from "@/actor/protocol/old";
import { type Encoding, jsonStringifyCompat } from "@/actor/protocol/serde";
Expand Down Expand Up @@ -693,6 +697,7 @@ enc
this.#encoding,
message,
TO_SERVER_VERSIONED,
ToServerSchema,
);
this.#websocket.send(messageSerialized);
logger().trace({
Expand Down Expand Up @@ -743,6 +748,7 @@ enc
this.#encoding,
buffer,
TO_CLIENT_VERSIONED,
ToClientSchema,
);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import * as cbor from "cbor-x";
import invariant from "invariant";
import {
HttpActionRequestSchema,
HttpActionResponseSchema,
} from "@/actor/client-protocol-schema-json/mod";
import type { AnyActorDefinition } from "@/actor/definition";
import type { Encoding } from "@/actor/protocol/serde";
import { assertUnreachable } from "@/actor/utils";
Expand Down Expand Up @@ -122,6 +126,8 @@ export class ActorHandleRaw {
signal: opts?.signal,
requestVersionedDataHandler: HTTP_ACTION_REQUEST_VERSIONED,
responseVersionedDataHandler: HTTP_ACTION_RESPONSE_VERSIONED,
requestZodSchema: HttpActionRequestSchema,
responseZodSchema: HttpActionResponseSchema,
});

return cbor.decode(new Uint8Array(responseData.output));
Expand Down
28 changes: 22 additions & 6 deletions rivetkit-typescript/packages/rivetkit/src/client/utils.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import * as cbor from "cbor-x";
import invariant from "invariant";
import type { z } from "zod";
import { HttpResponseErrorSchema } from "@/actor/client-protocol-schema-json/mod";
import type { Encoding } from "@/actor/protocol/serde";
import { assertUnreachable } from "@/common/utils";
import type { VersionedDataHandler } from "@/common/versioned-data";
Expand Down Expand Up @@ -33,11 +35,16 @@ export function messageLength(message: WebSocketMessage): number {
assertUnreachable(message);
}

export interface HttpRequestOpts<RequestBody, ResponseBody> {
export interface HttpRequestOpts<
RequestBody,
ResponseBody,
RequestJson = RequestBody,
ResponseJson = ResponseBody,
> {
method: string;
url: string;
headers: Record<string, string>;
body?: RequestBody;
body?: RequestBody | RequestJson;
encoding: Encoding;
skipParseResponse?: boolean;
signal?: AbortSignal;
Expand All @@ -46,12 +53,18 @@ export interface HttpRequestOpts<RequestBody, ResponseBody> {
responseVersionedDataHandler:
| VersionedDataHandler<ResponseBody>
| undefined;
requestZodSchema: z.ZodType<RequestJson>;
responseZodSchema: z.ZodType<ResponseJson>;
}

export async function sendHttpRequest<
RequestBody = unknown,
ResponseBody = unknown,
>(opts: HttpRequestOpts<RequestBody, ResponseBody>): Promise<ResponseBody> {
RequestJson = RequestBody,
ResponseJson = ResponseBody,
>(
opts: HttpRequestOpts<RequestBody, ResponseBody, RequestJson, ResponseJson>,
): Promise<ResponseBody | ResponseJson> {
logger().debug({
msg: "sending http request",
url: opts.url,
Expand All @@ -64,10 +77,11 @@ export async function sendHttpRequest<
if (opts.method === "POST" || opts.method === "PUT") {
invariant(opts.body !== undefined, "missing body");
contentType = contentTypeForEncoding(opts.encoding);
bodyData = serializeWithEncoding<RequestBody>(
bodyData = serializeWithEncoding<RequestBody, RequestJson>(
opts.encoding,
opts.body,
opts.requestVersionedDataHandler,
opts.requestZodSchema,
);
}

Expand Down Expand Up @@ -108,6 +122,7 @@ export async function sendHttpRequest<
opts.encoding,
new Uint8Array(bufferResponse),
HTTP_RESPONSE_ERROR_VERSIONED,
HttpResponseErrorSchema,
);
} catch (error) {
//logger().warn("failed to cleanly parse error, this is likely because a non-structured response is being served", {
Expand Down Expand Up @@ -151,16 +166,17 @@ export async function sendHttpRequest<

// Some requests don't need the success response to be parsed, so this can speed things up
if (opts.skipParseResponse) {
return undefined as ResponseBody;
return undefined as ResponseBody | ResponseJson;
}

// Parse the response based on encoding
try {
const buffer = new Uint8Array(await response.arrayBuffer());
return deserializeWithEncoding(
return deserializeWithEncoding<ResponseBody, ResponseJson>(
opts.encoding,
buffer,
opts.responseVersionedDataHandler,
opts.responseZodSchema,
);
} catch (error) {
throw new HttpRequestError(`Failed to parse response: ${error}`, {
Expand Down
Loading
Loading