Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ export interface ConnDriver {
sendMessage?(
actor: AnyActorInstance,
conn: AnyConn,
message: CachedSerializer<protocol.ToClient>,
message: CachedSerializer<any, any, any>,
): void;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ export function createWebSocketSocket(
sendMessage: (
actor: AnyActorInstance,
conn: AnyConn,
message: CachedSerializer<protocol.ToClient>,
message: CachedSerializer<any, any, any>,
) => {
if (websocket.readyState !== DriverReadyState.OPEN) {
actor.rLog.warn({
Expand Down
36 changes: 26 additions & 10 deletions rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import * as cbor from "cbor-x";
import { ToClientSchema } from "@/actor/client-protocol-schema-json/mod";
import type * as protocol from "@/schemas/client-protocol/mod";
import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned";
import {
type ToClient as ToClientJson,
ToClientSchema,
} from "@/schemas/client-protocol-zod/mod";
import { arrayBuffersEqual, bufferToArrayBuffer } from "@/utils";
import type { AnyDatabaseProvider } from "../database";
import {
Expand Down Expand Up @@ -161,7 +164,7 @@ export class Conn<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
return this.#stateManager.persistRaw;
}

[CONN_SEND_MESSAGE_SYMBOL](message: CachedSerializer<protocol.ToClient>) {
[CONN_SEND_MESSAGE_SYMBOL](message: CachedSerializer<any, any, any>) {
if (this[CONN_DRIVER_SYMBOL]) {
const driver = this[CONN_DRIVER_SYMBOL];
if (driver.sendMessage) {
Expand Down Expand Up @@ -194,19 +197,32 @@ export class Conn<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
args,
connId: this.id,
});
const eventData = { name: eventName, args };
this[CONN_SEND_MESSAGE_SYMBOL](
new CachedSerializer<protocol.ToClient>(
{
new CachedSerializer(
eventData,
TO_CLIENT_VERSIONED,
ToClientSchema,
// JSON: args is the raw value (array of arguments)
(value): ToClientJson => ({
body: {
tag: "Event",
tag: "Event" as const,
val: {
name: eventName,
args: bufferToArrayBuffer(cbor.encode(args)),
name: value.name,
args: value.args,
},
},
},
TO_CLIENT_VERSIONED,
ToClientSchema,
}),
// BARE/CBOR: args needs to be CBOR-encoded to ArrayBuffer
(value): protocol.ToClient => ({
body: {
tag: "Event" as const,
val: {
name: value.name,
args: bufferToArrayBuffer(cbor.encode(value.args)),
},
},
}),
),
);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import * as cbor from "cbor-x";
import { ToClientSchema } from "@/actor/client-protocol-schema-json/mod";
import type * as protocol from "@/schemas/client-protocol/mod";
import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned";
import {
type ToClient as ToClientJson,
ToClientSchema,
} from "@/schemas/client-protocol-zod/mod";
import { bufferToArrayBuffer } from "@/utils";
import {
CONN_PERSIST_SYMBOL,
Expand Down Expand Up @@ -180,18 +183,31 @@ export class EventManager<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
}

// Create serialized message
const toClientSerializer = new CachedSerializer<protocol.ToClient>(
{
const eventData = { name, args };
const toClientSerializer = new CachedSerializer(
eventData,
TO_CLIENT_VERSIONED,
ToClientSchema,
// JSON: args is the raw value (array of arguments)
(value): ToClientJson => ({
body: {
tag: "Event",
tag: "Event" as const,
val: {
name,
args: bufferToArrayBuffer(cbor.encode(args)),
name: value.name,
args: value.args,
},
},
},
TO_CLIENT_VERSIONED,
ToClientSchema,
}),
// BARE/CBOR: args needs to be CBOR-encoded to ArrayBuffer
(value): protocol.ToClient => ({
body: {
tag: "Event" as const,
val: {
name: value.name,
args: bufferToArrayBuffer(cbor.encode(value.args)),
},
},
}),
);

// Send to all subscribers
Expand Down
41 changes: 29 additions & 12 deletions rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import * as cbor from "cbor-x";
import invariant from "invariant";
import { ToClientSchema } from "@/actor/client-protocol-schema-json/mod";
import type { ActorKey } from "@/actor/mod";
import type { Client } from "@/client/client";
import { getBaseLogger, getIncludeTarget, type Logger } from "@/common/log";
Expand All @@ -11,6 +10,7 @@ import type { Registry } from "@/mod";
import { ACTOR_VERSIONED } from "@/schemas/actor-persist/versioned";
import type * as protocol from "@/schemas/client-protocol/mod";
import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned";
import { ToClientSchema } from "@/schemas/client-protocol-zod/mod";
import { EXTRA_ERROR_LOG, idToStr } from "@/utils";
import type { ActorConfig, InitContext } from "../config";
import type { ConnDriver } from "../conn/driver";
Expand Down Expand Up @@ -511,19 +511,26 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
await this.saveState({ immediate: true });

// Send init message
const initData = { actorId: this.id, connectionId: conn.id };
conn[CONN_SEND_MESSAGE_SYMBOL](
new CachedSerializer<protocol.ToClient>(
{
body: {
tag: "Init",
val: {
actorId: this.id,
connectionId: conn.id,
},
},
},
new CachedSerializer(
initData,
TO_CLIENT_VERSIONED,
ToClientSchema,
// JSON: identity conversion (no nested data to encode)
(value) => ({
body: {
tag: "Init" as const,
val: value,
},
}),
// BARE/CBOR: identity conversion (no nested data to encode)
(value) => ({
body: {
tag: "Init" as const,
val: value,
},
}),
),
);

Expand All @@ -532,7 +539,17 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {

// MARK: - Message Processing
async processMessage(
message: protocol.ToServer,
message: {
body:
| {
tag: "ActionRequest";
val: { id: bigint; name: string; args: unknown };
}
| {
tag: "SubscriptionRequest";
val: { eventName: string; subscribe: boolean };
};
},
conn: Conn<S, CP, CS, V, I, DB>,
) {
await processMessage(message, this, conn, {
Expand Down
Loading
Loading