Skip to content

Commit be3dc7d

Browse files
committed
chore(rivetkit): add required zod validation for json encoding
1 parent 51507e2 commit be3dc7d

File tree

13 files changed

+193
-19
lines changed

13 files changed

+193
-19
lines changed
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
import { z } from "zod";
2+
3+
// Helper schemas for ArrayBuffer handling in JSON
4+
const ArrayBufferSchema = z.instanceof(ArrayBuffer);
5+
const OptionalArrayBufferSchema = ArrayBufferSchema.nullable();
6+
const UintSchema = z.bigint();
7+
const OptionalUintSchema = UintSchema.nullable();
8+
9+
// MARK: Message To Client
10+
export const InitSchema = z.object({
11+
actorId: z.string(),
12+
connectionId: z.string(),
13+
});
14+
export type Init = z.infer<typeof InitSchema>;
15+
16+
export const ErrorSchema = z.object({
17+
group: z.string(),
18+
code: z.string(),
19+
message: z.string(),
20+
metadata: OptionalArrayBufferSchema,
21+
actionId: OptionalUintSchema,
22+
});
23+
export type Error = z.infer<typeof ErrorSchema>;
24+
25+
export const ActionResponseSchema = z.object({
26+
id: UintSchema,
27+
output: ArrayBufferSchema,
28+
});
29+
export type ActionResponse = z.infer<typeof ActionResponseSchema>;
30+
31+
export const EventSchema = z.object({
32+
name: z.string(),
33+
args: ArrayBufferSchema,
34+
});
35+
export type Event = z.infer<typeof EventSchema>;
36+
37+
export const ToClientBodySchema = z.discriminatedUnion("tag", [
38+
z.object({ tag: z.literal("Init"), val: InitSchema }),
39+
z.object({ tag: z.literal("Error"), val: ErrorSchema }),
40+
z.object({ tag: z.literal("ActionResponse"), val: ActionResponseSchema }),
41+
z.object({ tag: z.literal("Event"), val: EventSchema }),
42+
]);
43+
export type ToClientBody = z.infer<typeof ToClientBodySchema>;
44+
45+
export const ToClientSchema = z.object({
46+
body: ToClientBodySchema,
47+
});
48+
export type ToClient = z.infer<typeof ToClientSchema>;
49+
50+
// MARK: Message To Server
51+
export const ActionRequestSchema = z.object({
52+
id: UintSchema,
53+
name: z.string(),
54+
args: ArrayBufferSchema,
55+
});
56+
export type ActionRequest = z.infer<typeof ActionRequestSchema>;
57+
58+
export const SubscriptionRequestSchema = z.object({
59+
eventName: z.string(),
60+
subscribe: z.boolean(),
61+
});
62+
export type SubscriptionRequest = z.infer<typeof SubscriptionRequestSchema>;
63+
64+
export const ToServerBodySchema = z.discriminatedUnion("tag", [
65+
z.object({ tag: z.literal("ActionRequest"), val: ActionRequestSchema }),
66+
z.object({
67+
tag: z.literal("SubscriptionRequest"),
68+
val: SubscriptionRequestSchema,
69+
}),
70+
]);
71+
export type ToServerBody = z.infer<typeof ToServerBodySchema>;
72+
73+
export const ToServerSchema = z.object({
74+
body: ToServerBodySchema,
75+
});
76+
export type ToServer = z.infer<typeof ToServerSchema>;
77+
78+
// MARK: HTTP Action
79+
export const HttpActionRequestSchema = z.object({
80+
args: ArrayBufferSchema,
81+
});
82+
export type HttpActionRequest = z.infer<typeof HttpActionRequestSchema>;
83+
84+
export const HttpActionResponseSchema = z.object({
85+
output: ArrayBufferSchema,
86+
});
87+
export type HttpActionResponse = z.infer<typeof HttpActionResponseSchema>;
88+
89+
// MARK: HTTP Error
90+
export const HttpResponseErrorSchema = z.object({
91+
group: z.string(),
92+
code: z.string(),
93+
message: z.string(),
94+
metadata: OptionalArrayBufferSchema,
95+
});
96+
export type HttpResponseError = z.infer<typeof HttpResponseErrorSchema>;
97+
98+
// MARK: HTTP Resolve
99+
export const HttpResolveRequestSchema = z.null();
100+
export type HttpResolveRequest = z.infer<typeof HttpResolveRequestSchema>;
101+
102+
export const HttpResolveResponseSchema = z.object({
103+
actorId: z.string(),
104+
});
105+
export type HttpResolveResponse = z.infer<typeof HttpResolveResponseSchema>;

rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import * as cbor from "cbor-x";
2+
import { ToClientSchema } from "@/actor/client-protocol-schema-json/mod";
23
import type * as protocol from "@/schemas/client-protocol/mod";
34
import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned";
45
import { arrayBuffersEqual, bufferToArrayBuffer } from "@/utils";
@@ -205,6 +206,7 @@ export class Conn<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
205206
},
206207
},
207208
TO_CLIENT_VERSIONED,
209+
ToClientSchema,
208210
),
209211
);
210212
}

rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import * as cbor from "cbor-x";
2+
import { ToClientSchema } from "@/actor/client-protocol-schema-json/mod";
23
import type * as protocol from "@/schemas/client-protocol/mod";
34
import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned";
45
import { bufferToArrayBuffer } from "@/utils";
@@ -190,6 +191,7 @@ export class EventManager<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
190191
},
191192
},
192193
TO_CLIENT_VERSIONED,
194+
ToClientSchema,
193195
);
194196

195197
// Send to all subscribers

rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import * as cbor from "cbor-x";
22
import invariant from "invariant";
3+
import { ToClientSchema } from "@/actor/client-protocol-schema-json/mod";
34
import type { ActorKey } from "@/actor/mod";
45
import type { Client } from "@/client/client";
56
import { getBaseLogger, getIncludeTarget, type Logger } from "@/common/log";
@@ -522,6 +523,7 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
522523
},
523524
},
524525
TO_CLIENT_VERSIONED,
526+
ToClientSchema,
525527
),
526528
);
527529

rivetkit-typescript/packages/rivetkit/src/actor/protocol/old.ts

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
import * as cbor from "cbor-x";
22
import { z } from "zod";
3+
import {
4+
ToClientSchema,
5+
ToServerSchema,
6+
} from "@/actor/client-protocol-schema-json/mod";
37
import type { AnyDatabaseProvider } from "@/actor/database";
48
import * as errors from "@/actor/errors";
59
import {
@@ -81,7 +85,12 @@ export async function parseMessage(
8185
}
8286

8387
// Deserialize message
84-
return deserializeWithEncoding(opts.encoding, buffer, TO_SERVER_VERSIONED);
88+
return deserializeWithEncoding(
89+
opts.encoding,
90+
buffer,
91+
TO_SERVER_VERSIONED,
92+
ToServerSchema,
93+
);
8594
}
8695

8796
export interface ProcessMessageHandler<
@@ -171,6 +180,7 @@ export async function processMessage<
171180
},
172181
},
173182
TO_CLIENT_VERSIONED,
183+
ToClientSchema,
174184
),
175185
);
176186

@@ -243,6 +253,7 @@ export async function processMessage<
243253
},
244254
},
245255
TO_CLIENT_VERSIONED,
256+
ToClientSchema,
246257
),
247258
);
248259

rivetkit-typescript/packages/rivetkit/src/actor/protocol/serde.ts

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,23 @@ export type Encoding = z.infer<typeof EncodingSchema>;
2222
/**
2323
* Helper class that helps serialize data without re-serializing for the same encoding.
2424
*/
25-
export class CachedSerializer<T> {
26-
#data: T;
25+
export class CachedSerializer<T, TJson = T> {
26+
#data: T | TJson;
2727
#cache = new Map<Encoding, OutputData>();
2828
#versionedDataHandler: VersionedDataHandler<T>;
29+
#zodSchema: z.ZodType<TJson>;
2930

30-
constructor(data: T, versionedDataHandler: VersionedDataHandler<T>) {
31+
constructor(
32+
data: T | TJson,
33+
versionedDataHandler: VersionedDataHandler<T>,
34+
zodSchema: z.ZodType<TJson>,
35+
) {
3136
this.#data = data;
3237
this.#versionedDataHandler = versionedDataHandler;
38+
this.#zodSchema = zodSchema;
3339
}
3440

35-
public get rawData(): T {
41+
public get rawData(): T | TJson {
3642
return this.#data;
3743
}
3844

@@ -45,6 +51,7 @@ export class CachedSerializer<T> {
4551
encoding,
4652
this.#data,
4753
this.#versionedDataHandler,
54+
this.#zodSchema,
4855
);
4956
this.#cache.set(encoding, serialized);
5057
return serialized;

rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import * as cbor from "cbor-x";
22
import type { Context as HonoContext, HonoRequest } from "hono";
33
import type { WSContext } from "hono/ws";
4+
import {
5+
HttpActionRequestSchema,
6+
HttpActionResponseSchema,
7+
} from "@/actor/client-protocol-schema-json/mod";
48
import type { AnyConn } from "@/actor/conn/mod";
59
import { ActionContext } from "@/actor/contexts/action";
610
import * as errors from "@/actor/errors";
@@ -334,6 +338,7 @@ export async function handleAction(
334338
encoding,
335339
new Uint8Array(arrayBuffer),
336340
HTTP_ACTION_REQUEST_VERSIONED,
341+
HttpActionRequestSchema,
337342
);
338343
const actionArgs = cbor.decode(new Uint8Array(request.args));
339344

@@ -371,6 +376,7 @@ export async function handleAction(
371376
encoding,
372377
responseData,
373378
HTTP_ACTION_RESPONSE_VERSIONED,
379+
HttpActionResponseSchema,
374380
);
375381

376382
// TODO: Remvoe any, Hono is being a dumbass

rivetkit-typescript/packages/rivetkit/src/client/actor-conn.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@ import * as cbor from "cbor-x";
22
import invariant from "invariant";
33
import pRetry from "p-retry";
44
import type { CloseEvent } from "ws";
5+
import {
6+
ToClientSchema,
7+
ToServerSchema,
8+
} from "@/actor/client-protocol-schema-json/mod";
59
import type { AnyActorDefinition } from "@/actor/definition";
610
import { inputDataToBuffer } from "@/actor/protocol/old";
711
import { type Encoding, jsonStringifyCompat } from "@/actor/protocol/serde";
@@ -693,6 +697,7 @@ enc
693697
this.#encoding,
694698
message,
695699
TO_SERVER_VERSIONED,
700+
ToServerSchema,
696701
);
697702
this.#websocket.send(messageSerialized);
698703
logger().trace({
@@ -743,6 +748,7 @@ enc
743748
this.#encoding,
744749
buffer,
745750
TO_CLIENT_VERSIONED,
751+
ToClientSchema,
746752
);
747753
}
748754

rivetkit-typescript/packages/rivetkit/src/client/actor-handle.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
import * as cbor from "cbor-x";
22
import invariant from "invariant";
3+
import {
4+
HttpActionRequestSchema,
5+
HttpActionResponseSchema,
6+
} from "@/actor/client-protocol-schema-json/mod";
37
import type { AnyActorDefinition } from "@/actor/definition";
48
import type { Encoding } from "@/actor/protocol/serde";
59
import { assertUnreachable } from "@/actor/utils";
@@ -122,6 +126,8 @@ export class ActorHandleRaw {
122126
signal: opts?.signal,
123127
requestVersionedDataHandler: HTTP_ACTION_REQUEST_VERSIONED,
124128
responseVersionedDataHandler: HTTP_ACTION_RESPONSE_VERSIONED,
129+
requestZodSchema: HttpActionRequestSchema,
130+
responseZodSchema: HttpActionResponseSchema,
125131
});
126132

127133
return cbor.decode(new Uint8Array(responseData.output));

rivetkit-typescript/packages/rivetkit/src/client/utils.ts

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import * as cbor from "cbor-x";
22
import invariant from "invariant";
3+
import type { z } from "zod";
4+
import { HttpResponseErrorSchema } from "@/actor/client-protocol-schema-json/mod";
35
import type { Encoding } from "@/actor/protocol/serde";
46
import { assertUnreachable } from "@/common/utils";
57
import type { VersionedDataHandler } from "@/common/versioned-data";
@@ -33,11 +35,16 @@ export function messageLength(message: WebSocketMessage): number {
3335
assertUnreachable(message);
3436
}
3537

36-
export interface HttpRequestOpts<RequestBody, ResponseBody> {
38+
export interface HttpRequestOpts<
39+
RequestBody,
40+
ResponseBody,
41+
RequestJson = RequestBody,
42+
ResponseJson = ResponseBody,
43+
> {
3744
method: string;
3845
url: string;
3946
headers: Record<string, string>;
40-
body?: RequestBody;
47+
body?: RequestBody | RequestJson;
4148
encoding: Encoding;
4249
skipParseResponse?: boolean;
4350
signal?: AbortSignal;
@@ -46,12 +53,18 @@ export interface HttpRequestOpts<RequestBody, ResponseBody> {
4653
responseVersionedDataHandler:
4754
| VersionedDataHandler<ResponseBody>
4855
| undefined;
56+
requestZodSchema: z.ZodType<RequestJson>;
57+
responseZodSchema: z.ZodType<ResponseJson>;
4958
}
5059

5160
export async function sendHttpRequest<
5261
RequestBody = unknown,
5362
ResponseBody = unknown,
54-
>(opts: HttpRequestOpts<RequestBody, ResponseBody>): Promise<ResponseBody> {
63+
RequestJson = RequestBody,
64+
ResponseJson = ResponseBody,
65+
>(
66+
opts: HttpRequestOpts<RequestBody, ResponseBody, RequestJson, ResponseJson>,
67+
): Promise<ResponseBody | ResponseJson> {
5568
logger().debug({
5669
msg: "sending http request",
5770
url: opts.url,
@@ -64,10 +77,11 @@ export async function sendHttpRequest<
6477
if (opts.method === "POST" || opts.method === "PUT") {
6578
invariant(opts.body !== undefined, "missing body");
6679
contentType = contentTypeForEncoding(opts.encoding);
67-
bodyData = serializeWithEncoding<RequestBody>(
80+
bodyData = serializeWithEncoding<RequestBody, RequestJson>(
6881
opts.encoding,
6982
opts.body,
7083
opts.requestVersionedDataHandler,
84+
opts.requestZodSchema,
7185
);
7286
}
7387

@@ -108,6 +122,7 @@ export async function sendHttpRequest<
108122
opts.encoding,
109123
new Uint8Array(bufferResponse),
110124
HTTP_RESPONSE_ERROR_VERSIONED,
125+
HttpResponseErrorSchema,
111126
);
112127
} catch (error) {
113128
//logger().warn("failed to cleanly parse error, this is likely because a non-structured response is being served", {
@@ -151,16 +166,17 @@ export async function sendHttpRequest<
151166

152167
// Some requests don't need the success response to be parsed, so this can speed things up
153168
if (opts.skipParseResponse) {
154-
return undefined as ResponseBody;
169+
return undefined as ResponseBody | ResponseJson;
155170
}
156171

157172
// Parse the response based on encoding
158173
try {
159174
const buffer = new Uint8Array(await response.arrayBuffer());
160-
return deserializeWithEncoding(
175+
return deserializeWithEncoding<ResponseBody, ResponseJson>(
161176
opts.encoding,
162177
buffer,
163178
opts.responseVersionedDataHandler,
179+
opts.responseZodSchema,
164180
);
165181
} catch (error) {
166182
throw new HttpRequestError(`Failed to parse response: ${error}`, {

0 commit comments

Comments
 (0)