Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions rivetkit-typescript/packages/rivetkit/src/actor/conn/driver.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import type { AnyConn } from "@/actor/conn/mod";
import type { AnyActorInstance } from "@/actor/instance/mod";
import type { CachedSerializer } from "@/actor/protocol/serde";
import type * as protocol from "@/schemas/client-protocol/mod";

export enum DriverReadyState {
UNKNOWN = -1,
Expand All @@ -15,6 +14,22 @@ export interface ConnDriver {
/** The type of driver. Used for debug purposes only. */
type: string;

/**
* If defined, this connection driver talks the RivetKit client driver (see
* schemas/client-protocol/).
*
* If enabled, events like `Init`, subscription events, etc. will be sent
* to this connection.
*/
rivetKitProtocol?: {
/** Sends a RivetKit client message. */
sendMessage(
actor: AnyActorInstance,
conn: AnyConn,
message: CachedSerializer<any, any, any>,
): void;
};

/**
* Unique request ID provided by the underlying provider. If none is
* available for this conn driver, a random UUID is generated.
Expand All @@ -29,12 +44,6 @@ export interface ConnDriver {
**/
hibernatable: boolean;

sendMessage?(
actor: AnyActorInstance,
conn: AnyConn,
message: CachedSerializer<any, any, any>,
): void;

/**
* This returns a promise since we commonly disconnect at the end of a program, and not waiting will cause the socket to not close cleanly.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,69 +23,71 @@ export function createWebSocketSocket(
requestId,
requestIdBuf,
hibernatable,
sendMessage: (
actor: AnyActorInstance,
conn: AnyConn,
message: CachedSerializer<any, any, any>,
) => {
if (!websocket) {
actor.rLog.warn({
msg: "websocket not open",
connId: conn.id,
});
return;
}
if (websocket.readyState !== DriverReadyState.OPEN) {
actor.rLog.warn({
msg: "attempting to send message to closed websocket, this is likely a bug in RivetKit",
connId: conn.id,
wsReadyState: websocket.readyState,
});
return;
}
rivetKitProtocol: {
sendMessage: (
actor: AnyActorInstance,
conn: AnyConn,
message: CachedSerializer<any, any, any>,
) => {
if (!websocket) {
actor.rLog.warn({
msg: "websocket not open",
connId: conn.id,
});
return;
}
if (websocket.readyState !== DriverReadyState.OPEN) {
actor.rLog.warn({
msg: "attempting to send message to closed websocket, this is likely a bug in RivetKit",
connId: conn.id,
wsReadyState: websocket.readyState,
});
return;
}

const serialized = message.serialize(encoding);
const serialized = message.serialize(encoding);

actor.rLog.debug({
msg: "sending websocket message",
encoding: encoding,
dataType: typeof serialized,
isUint8Array: serialized instanceof Uint8Array,
isArrayBuffer: serialized instanceof ArrayBuffer,
dataLength:
(serialized as any).byteLength ||
(serialized as any).length,
});
actor.rLog.debug({
msg: "sending websocket message",
encoding: encoding,
dataType: typeof serialized,
isUint8Array: serialized instanceof Uint8Array,
isArrayBuffer: serialized instanceof ArrayBuffer,
dataLength:
(serialized as any).byteLength ||
(serialized as any).length,
});

// Convert Uint8Array to ArrayBuffer for proper transmission
if (serialized instanceof Uint8Array) {
const buffer = serialized.buffer.slice(
serialized.byteOffset,
serialized.byteOffset + serialized.byteLength,
);
// Handle SharedArrayBuffer case
if (buffer instanceof SharedArrayBuffer) {
const arrayBuffer = new ArrayBuffer(buffer.byteLength);
new Uint8Array(arrayBuffer).set(new Uint8Array(buffer));
actor.rLog.debug({
msg: "converted SharedArrayBuffer to ArrayBuffer",
byteLength: arrayBuffer.byteLength,
});
websocket.send(arrayBuffer);
// Convert Uint8Array to ArrayBuffer for proper transmission
if (serialized instanceof Uint8Array) {
const buffer = serialized.buffer.slice(
serialized.byteOffset,
serialized.byteOffset + serialized.byteLength,
);
// Handle SharedArrayBuffer case
if (buffer instanceof SharedArrayBuffer) {
const arrayBuffer = new ArrayBuffer(buffer.byteLength);
new Uint8Array(arrayBuffer).set(new Uint8Array(buffer));
actor.rLog.debug({
msg: "converted SharedArrayBuffer to ArrayBuffer",
byteLength: arrayBuffer.byteLength,
});
websocket.send(arrayBuffer);
} else {
actor.rLog.debug({
msg: "sending ArrayBuffer",
byteLength: buffer.byteLength,
});
websocket.send(buffer);
}
} else {
actor.rLog.debug({
msg: "sending ArrayBuffer",
byteLength: buffer.byteLength,
msg: "sending string data",
length: (serialized as string).length,
});
websocket.send(buffer);
websocket.send(serialized);
}
} else {
actor.rLog.debug({
msg: "sending string data",
length: (serialized as string).length,
});
websocket.send(serialized);
}
},
},

disconnect: async (
Expand Down
19 changes: 16 additions & 3 deletions rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ export type ConnId = string;
export type AnyConn = Conn<any, any, any, any, any, any>;

export const CONN_CONNECTED_SYMBOL = Symbol("connected");
export const CONN_SPEAKS_RIVETKIT_SYMBOL = Symbol("speaksRivetKit");
export const CONN_PERSIST_SYMBOL = Symbol("persist");
export const CONN_DRIVER_SYMBOL = Symbol("driver");
export const CONN_ACTOR_SYMBOL = Symbol("actor");
Expand Down Expand Up @@ -62,6 +63,10 @@ export class Conn<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
/** Connections exist before being connected to an actor. If true, this connection has been connected. */
[CONN_CONNECTED_SYMBOL] = false;

[CONN_SPEAKS_RIVETKIT_SYMBOL](): boolean {
return this[CONN_DRIVER_SYMBOL]?.rivetKitProtocol !== undefined;
}

#assertConnected() {
if (!this[CONN_CONNECTED_SYMBOL])
throw new InternalError(
Expand Down Expand Up @@ -174,11 +179,12 @@ export class Conn<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
[CONN_SEND_MESSAGE_SYMBOL](message: CachedSerializer<any, any, any>) {
if (this[CONN_DRIVER_SYMBOL]) {
const driver = this[CONN_DRIVER_SYMBOL];
if (driver.sendMessage) {
driver.sendMessage(this.#actor, this, message);

if (driver.rivetKitProtocol) {
driver.rivetKitProtocol.sendMessage(this.#actor, this, message);
} else {
this.#actor.rLog.debug({
msg: "conn driver does not support sending messages",
msg: "attempting to send RivetKit protocol message to connection that does not support it",
conn: this.id,
});
}
Expand All @@ -199,6 +205,13 @@ export class Conn<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
*/
send(eventName: string, ...args: unknown[]) {
this.#assertConnected();
if (!this[CONN_SPEAKS_RIVETKIT_SYMBOL]) {
this.#actor.rLog.warn({
msg: "cannot send messages to this connection type",
connId: this.id,
connType: this[CONN_DRIVER_SYMBOL]?.type,
});
}

this.#actor.inspector.emitter.emit("eventFired", {
type: "event",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
CONN_PERSIST_RAW_SYMBOL,
CONN_PERSIST_SYMBOL,
CONN_SEND_MESSAGE_SYMBOL,
CONN_SPEAKS_RIVETKIT_SYMBOL,
Conn,
type ConnId,
} from "../conn/mod";
Expand Down Expand Up @@ -155,30 +156,31 @@ export class ConnectionManager<

conn[CONN_CONNECTED_SYMBOL] = true;

// TODO: Only do this for action messages
// Send init message
const initData = { actorId: this.#actor.id, connectionId: conn.id };
conn[CONN_SEND_MESSAGE_SYMBOL](
new CachedSerializer(
initData,
TO_CLIENT_VERSIONED,
ToClientSchema,
// JSON: identity conversion (no nested data to encode)
(value) => ({
body: {
tag: "Init" as const,
val: value,
},
}),
// BARE/CBOR: identity conversion (no nested data to encode)
(value) => ({
body: {
tag: "Init" as const,
val: value,
},
}),
),
);
if (conn[CONN_SPEAKS_RIVETKIT_SYMBOL]) {
const initData = { actorId: this.#actor.id, connectionId: conn.id };
conn[CONN_SEND_MESSAGE_SYMBOL](
new CachedSerializer(
initData,
TO_CLIENT_VERSIONED,
ToClientSchema,
// JSON: identity conversion (no nested data to encode)
(value) => ({
body: {
tag: "Init" as const,
val: value,
},
}),
// BARE/CBOR: identity conversion (no nested data to encode)
(value) => ({
body: {
tag: "Init" as const,
val: value,
},
}),
),
);
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import { bufferToArrayBuffer } from "@/utils";
import {
CONN_PERSIST_SYMBOL,
CONN_SEND_MESSAGE_SYMBOL,
CONN_SPEAKS_RIVETKIT_SYMBOL,
type Conn,
} from "../conn/mod";
import type { AnyDatabaseProvider } from "../database";
Expand Down Expand Up @@ -215,17 +216,21 @@ export class EventManager<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
// Send to all subscribers
let sentCount = 0;
for (const connection of subscribers) {
try {
connection[CONN_SEND_MESSAGE_SYMBOL](toClientSerializer);
sentCount++;
} catch (error) {
this.#actor.rLog.error({
msg: "failed to send event to connection",
eventName: name,
connId: connection.id,
error:
error instanceof Error ? error.message : String(error),
});
if (connection[CONN_SPEAKS_RIVETKIT_SYMBOL]) {
try {
connection[CONN_SEND_MESSAGE_SYMBOL](toClientSerializer);
sentCount++;
} catch (error) {
this.#actor.rLog.error({
msg: "failed to send event to connection",
eventName: name,
connId: connection.id,
error:
error instanceof Error
? error.message
: String(error),
});
}
}
}

Expand Down
Loading