diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn/driver.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn/driver.ts index 1c7da41096..8530841a3e 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/conn/driver.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn/driver.ts @@ -1,7 +1,6 @@ import type { AnyConn } from "@/actor/conn/mod"; import type { AnyActorInstance } from "@/actor/instance/mod"; import type { CachedSerializer } from "@/actor/protocol/serde"; -import type * as protocol from "@/schemas/client-protocol/mod"; export enum DriverReadyState { UNKNOWN = -1, @@ -15,6 +14,22 @@ export interface ConnDriver { /** The type of driver. Used for debug purposes only. */ type: string; + /** + * If defined, this connection driver talks the RivetKit client driver (see + * schemas/client-protocol/). + * + * If enabled, events like `Init`, subscription events, etc. will be sent + * to this connection. + */ + rivetKitProtocol?: { + /** Sends a RivetKit client message. */ + sendMessage( + actor: AnyActorInstance, + conn: AnyConn, + message: CachedSerializer, + ): void; + }; + /** * Unique request ID provided by the underlying provider. If none is * available for this conn driver, a random UUID is generated. @@ -29,12 +44,6 @@ export interface ConnDriver { **/ hibernatable: boolean; - sendMessage?( - actor: AnyActorInstance, - conn: AnyConn, - message: CachedSerializer, - ): void; - /** * This returns a promise since we commonly disconnect at the end of a program, and not waiting will cause the socket to not close cleanly. */ diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/websocket.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/websocket.ts index b67959cf16..14dc4dbf62 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/websocket.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/websocket.ts @@ -23,69 +23,71 @@ export function createWebSocketSocket( requestId, requestIdBuf, hibernatable, - sendMessage: ( - actor: AnyActorInstance, - conn: AnyConn, - message: CachedSerializer, - ) => { - if (!websocket) { - actor.rLog.warn({ - msg: "websocket not open", - connId: conn.id, - }); - return; - } - if (websocket.readyState !== DriverReadyState.OPEN) { - actor.rLog.warn({ - msg: "attempting to send message to closed websocket, this is likely a bug in RivetKit", - connId: conn.id, - wsReadyState: websocket.readyState, - }); - return; - } + rivetKitProtocol: { + sendMessage: ( + actor: AnyActorInstance, + conn: AnyConn, + message: CachedSerializer, + ) => { + if (!websocket) { + actor.rLog.warn({ + msg: "websocket not open", + connId: conn.id, + }); + return; + } + if (websocket.readyState !== DriverReadyState.OPEN) { + actor.rLog.warn({ + msg: "attempting to send message to closed websocket, this is likely a bug in RivetKit", + connId: conn.id, + wsReadyState: websocket.readyState, + }); + return; + } - const serialized = message.serialize(encoding); + const serialized = message.serialize(encoding); - actor.rLog.debug({ - msg: "sending websocket message", - encoding: encoding, - dataType: typeof serialized, - isUint8Array: serialized instanceof Uint8Array, - isArrayBuffer: serialized instanceof ArrayBuffer, - dataLength: - (serialized as any).byteLength || - (serialized as any).length, - }); + actor.rLog.debug({ + msg: "sending websocket message", + encoding: encoding, + dataType: typeof serialized, + isUint8Array: serialized instanceof Uint8Array, + isArrayBuffer: serialized instanceof ArrayBuffer, + dataLength: + (serialized as any).byteLength || + (serialized as any).length, + }); - // Convert Uint8Array to ArrayBuffer for proper transmission - if (serialized instanceof Uint8Array) { - const buffer = serialized.buffer.slice( - serialized.byteOffset, - serialized.byteOffset + serialized.byteLength, - ); - // Handle SharedArrayBuffer case - if (buffer instanceof SharedArrayBuffer) { - const arrayBuffer = new ArrayBuffer(buffer.byteLength); - new Uint8Array(arrayBuffer).set(new Uint8Array(buffer)); - actor.rLog.debug({ - msg: "converted SharedArrayBuffer to ArrayBuffer", - byteLength: arrayBuffer.byteLength, - }); - websocket.send(arrayBuffer); + // Convert Uint8Array to ArrayBuffer for proper transmission + if (serialized instanceof Uint8Array) { + const buffer = serialized.buffer.slice( + serialized.byteOffset, + serialized.byteOffset + serialized.byteLength, + ); + // Handle SharedArrayBuffer case + if (buffer instanceof SharedArrayBuffer) { + const arrayBuffer = new ArrayBuffer(buffer.byteLength); + new Uint8Array(arrayBuffer).set(new Uint8Array(buffer)); + actor.rLog.debug({ + msg: "converted SharedArrayBuffer to ArrayBuffer", + byteLength: arrayBuffer.byteLength, + }); + websocket.send(arrayBuffer); + } else { + actor.rLog.debug({ + msg: "sending ArrayBuffer", + byteLength: buffer.byteLength, + }); + websocket.send(buffer); + } } else { actor.rLog.debug({ - msg: "sending ArrayBuffer", - byteLength: buffer.byteLength, + msg: "sending string data", + length: (serialized as string).length, }); - websocket.send(buffer); + websocket.send(serialized); } - } else { - actor.rLog.debug({ - msg: "sending string data", - length: (serialized as string).length, - }); - websocket.send(serialized); - } + }, }, disconnect: async ( diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts index d5854955c7..f81483bb51 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts @@ -23,6 +23,7 @@ export type ConnId = string; export type AnyConn = Conn; export const CONN_CONNECTED_SYMBOL = Symbol("connected"); +export const CONN_SPEAKS_RIVETKIT_SYMBOL = Symbol("speaksRivetKit"); export const CONN_PERSIST_SYMBOL = Symbol("persist"); export const CONN_DRIVER_SYMBOL = Symbol("driver"); export const CONN_ACTOR_SYMBOL = Symbol("actor"); @@ -62,6 +63,10 @@ export class Conn { /** Connections exist before being connected to an actor. If true, this connection has been connected. */ [CONN_CONNECTED_SYMBOL] = false; + [CONN_SPEAKS_RIVETKIT_SYMBOL](): boolean { + return this[CONN_DRIVER_SYMBOL]?.rivetKitProtocol !== undefined; + } + #assertConnected() { if (!this[CONN_CONNECTED_SYMBOL]) throw new InternalError( @@ -174,11 +179,12 @@ export class Conn { [CONN_SEND_MESSAGE_SYMBOL](message: CachedSerializer) { if (this[CONN_DRIVER_SYMBOL]) { const driver = this[CONN_DRIVER_SYMBOL]; - if (driver.sendMessage) { - driver.sendMessage(this.#actor, this, message); + + if (driver.rivetKitProtocol) { + driver.rivetKitProtocol.sendMessage(this.#actor, this, message); } else { this.#actor.rLog.debug({ - msg: "conn driver does not support sending messages", + msg: "attempting to send RivetKit protocol message to connection that does not support it", conn: this.id, }); } @@ -199,6 +205,13 @@ export class Conn { */ send(eventName: string, ...args: unknown[]) { this.#assertConnected(); + if (!this[CONN_SPEAKS_RIVETKIT_SYMBOL]) { + this.#actor.rLog.warn({ + msg: "cannot send messages to this connection type", + connId: this.id, + connType: this[CONN_DRIVER_SYMBOL]?.type, + }); + } this.#actor.inspector.emitter.emit("eventFired", { type: "event", diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/connection-manager.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/connection-manager.ts index 8ad2847c49..d6aa418a86 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance/connection-manager.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/connection-manager.ts @@ -11,6 +11,7 @@ import { CONN_PERSIST_RAW_SYMBOL, CONN_PERSIST_SYMBOL, CONN_SEND_MESSAGE_SYMBOL, + CONN_SPEAKS_RIVETKIT_SYMBOL, Conn, type ConnId, } from "../conn/mod"; @@ -155,30 +156,31 @@ export class ConnectionManager< conn[CONN_CONNECTED_SYMBOL] = true; - // TODO: Only do this for action messages // Send init message - const initData = { actorId: this.#actor.id, connectionId: conn.id }; - conn[CONN_SEND_MESSAGE_SYMBOL]( - new CachedSerializer( - initData, - TO_CLIENT_VERSIONED, - ToClientSchema, - // JSON: identity conversion (no nested data to encode) - (value) => ({ - body: { - tag: "Init" as const, - val: value, - }, - }), - // BARE/CBOR: identity conversion (no nested data to encode) - (value) => ({ - body: { - tag: "Init" as const, - val: value, - }, - }), - ), - ); + if (conn[CONN_SPEAKS_RIVETKIT_SYMBOL]) { + const initData = { actorId: this.#actor.id, connectionId: conn.id }; + conn[CONN_SEND_MESSAGE_SYMBOL]( + new CachedSerializer( + initData, + TO_CLIENT_VERSIONED, + ToClientSchema, + // JSON: identity conversion (no nested data to encode) + (value) => ({ + body: { + tag: "Init" as const, + val: value, + }, + }), + // BARE/CBOR: identity conversion (no nested data to encode) + (value) => ({ + body: { + tag: "Init" as const, + val: value, + }, + }), + ), + ); + } } /** diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts index 20c252824b..344314ab79 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts @@ -9,6 +9,7 @@ import { bufferToArrayBuffer } from "@/utils"; import { CONN_PERSIST_SYMBOL, CONN_SEND_MESSAGE_SYMBOL, + CONN_SPEAKS_RIVETKIT_SYMBOL, type Conn, } from "../conn/mod"; import type { AnyDatabaseProvider } from "../database"; @@ -215,17 +216,21 @@ export class EventManager { // Send to all subscribers let sentCount = 0; for (const connection of subscribers) { - try { - connection[CONN_SEND_MESSAGE_SYMBOL](toClientSerializer); - sentCount++; - } catch (error) { - this.#actor.rLog.error({ - msg: "failed to send event to connection", - eventName: name, - connId: connection.id, - error: - error instanceof Error ? error.message : String(error), - }); + if (connection[CONN_SPEAKS_RIVETKIT_SYMBOL]) { + try { + connection[CONN_SEND_MESSAGE_SYMBOL](toClientSerializer); + sentCount++; + } catch (error) { + this.#actor.rLog.error({ + msg: "failed to send event to connection", + eventName: name, + connId: connection.id, + error: + error instanceof Error + ? error.message + : String(error), + }); + } } }