From 47cb55fdf6e76b0e62a0af2963e31ffef9c8e0c3 Mon Sep 17 00:00:00 2001 From: Nathan Flurry Date: Sat, 8 Nov 2025 20:05:46 -0800 Subject: [PATCH] chore(rivetkit): split ActorInstance logic in to multiple classes --- .../src/actor-handler-do.ts | 2 +- .../packages/rivetkit/src/actor/config.ts | 6 +- .../rivetkit/src/actor/conn-drivers.ts | 171 -- .../rivetkit/src/actor/conn-socket.ts | 9 - .../rivetkit/src/actor/conn/driver.ts | 45 + .../rivetkit/src/actor/conn/drivers/http.ts | 19 + .../src/actor/conn/drivers/websocket.ts | 103 + .../src/actor/{conn.ts => conn/mod.ts} | 101 +- .../src/actor/{ => contexts}/action.ts | 10 +- .../actor/{context.ts => contexts/actor.ts} | 14 +- .../packages/rivetkit/src/actor/definition.ts | 6 +- .../packages/rivetkit/src/actor/driver.ts | 2 +- .../packages/rivetkit/src/actor/instance.ts | 2217 ----------------- .../src/actor/instance/connection-manager.ts | 403 +++ .../src/actor/instance/event-manager.ts | 281 +++ .../rivetkit/src/actor/{ => instance}/kv.ts | 0 .../rivetkit/src/actor/instance/mod.ts | 1027 ++++++++ .../src/actor/{ => instance}/persisted.ts | 0 .../src/actor/instance/schedule-manager.ts | 349 +++ .../src/actor/instance/state-manager.ts | 440 ++++ .../packages/rivetkit/src/actor/mod.ts | 8 +- .../rivetkit/src/actor/protocol/old.ts | 10 +- .../rivetkit/src/actor/router-endpoints.ts | 65 +- .../packages/rivetkit/src/actor/router.ts | 17 +- .../packages/rivetkit/src/actor/schedule.ts | 2 +- .../rivetkit/src/actor/unstable-react.ts | 110 - .../rivetkit/src/driver-helpers/mod.ts | 2 +- .../src/drivers/engine/actor-driver.ts | 25 +- .../src/drivers/file-system/global-state.ts | 6 +- .../src/drivers/file-system/manager.ts | 2 +- .../rivetkit/tests/actor-types.test.ts | 2 +- 31 files changed, 2793 insertions(+), 2661 deletions(-) delete mode 100644 rivetkit-typescript/packages/rivetkit/src/actor/conn-drivers.ts delete mode 100644 rivetkit-typescript/packages/rivetkit/src/actor/conn-socket.ts create mode 100644 rivetkit-typescript/packages/rivetkit/src/actor/conn/driver.ts create mode 100644 rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/http.ts create mode 100644 rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/websocket.ts rename rivetkit-typescript/packages/rivetkit/src/actor/{conn.ts => conn/mod.ts} (73%) rename rivetkit-typescript/packages/rivetkit/src/actor/{ => contexts}/action.ts (93%) rename rivetkit-typescript/packages/rivetkit/src/actor/{context.ts => contexts/actor.ts} (90%) delete mode 100644 rivetkit-typescript/packages/rivetkit/src/actor/instance.ts create mode 100644 rivetkit-typescript/packages/rivetkit/src/actor/instance/connection-manager.ts create mode 100644 rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts rename rivetkit-typescript/packages/rivetkit/src/actor/{ => instance}/kv.ts (100%) create mode 100644 rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts rename rivetkit-typescript/packages/rivetkit/src/actor/{ => instance}/persisted.ts (100%) create mode 100644 rivetkit-typescript/packages/rivetkit/src/actor/instance/schedule-manager.ts create mode 100644 rivetkit-typescript/packages/rivetkit/src/actor/instance/state-manager.ts delete mode 100644 rivetkit-typescript/packages/rivetkit/src/actor/unstable-react.ts diff --git a/rivetkit-typescript/packages/cloudflare-workers/src/actor-handler-do.ts b/rivetkit-typescript/packages/cloudflare-workers/src/actor-handler-do.ts index e9cf43997a..2b0479c19c 100644 --- a/rivetkit-typescript/packages/cloudflare-workers/src/actor-handler-do.ts +++ b/rivetkit-typescript/packages/cloudflare-workers/src/actor-handler-do.ts @@ -203,7 +203,7 @@ export function createActorDurableObject( // Load the actor instance and trigger alarm const actor = await actorDriver.loadActor(actorId); - await actor._onAlarm(); + await actor.onAlarm(); } }; } diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/config.ts b/rivetkit-typescript/packages/rivetkit/src/actor/config.ts index 38a40b9262..b7fa486025 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/config.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/config.ts @@ -1,8 +1,8 @@ import { z } from "zod"; import type { UniversalWebSocket } from "@/common/websocket-interface"; -import type { ActionContext } from "./action"; -import type { Conn } from "./conn"; -import type { ActorContext } from "./context"; +import type { Conn } from "./conn/mod"; +import type { ActionContext } from "./contexts/action"; +import type { ActorContext } from "./contexts/actor"; import type { AnyDatabaseProvider } from "./database"; export type InitContext = ActorContext< diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn-drivers.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn-drivers.ts deleted file mode 100644 index 9b8d26a354..0000000000 --- a/rivetkit-typescript/packages/rivetkit/src/actor/conn-drivers.ts +++ /dev/null @@ -1,171 +0,0 @@ -import type { WSContext } from "hono/ws"; -import type { WebSocket } from "ws"; -import type { AnyConn } from "@/actor/conn"; -import type { AnyActorInstance } from "@/actor/instance"; -import type { CachedSerializer, Encoding } from "@/actor/protocol/serde"; -import { encodeDataToString } from "@/actor/protocol/serde"; -import type { HonoWebSocketAdapter } from "@/manager/hono-websocket-adapter"; -import type * as protocol from "@/schemas/client-protocol/mod"; -import { assertUnreachable, type promiseWithResolvers } from "@/utils"; - -export enum ConnDriverKind { - WEBSOCKET = 0, - HTTP = 2, -} - -export enum ConnReadyState { - UNKNOWN = -1, - CONNECTING = 0, - OPEN = 1, - CLOSING = 2, - CLOSED = 3, -} - -export interface ConnDriverWebSocketState { - encoding: Encoding; - websocket: WSContext; - closePromise: ReturnType>; -} - -export type ConnDriverHttpState = Record; - -export type ConnDriverState = - | { [ConnDriverKind.WEBSOCKET]: ConnDriverWebSocketState } - | { [ConnDriverKind.HTTP]: ConnDriverHttpState }; - -export interface ConnDriver { - sendMessage?( - actor: AnyActorInstance, - conn: AnyConn, - state: State, - message: CachedSerializer, - ): void; - - /** - * This returns a promise since we commonly disconnect at the end of a program, and not waiting will cause the socket to not close cleanly. - */ - disconnect( - actor: AnyActorInstance, - conn: AnyConn, - state: State, - reason?: string, - ): Promise; - - /** - * Returns the ready state of the connection. - * This is used to determine if the connection is ready to send messages, or if the connection is stale. - */ - getConnectionReadyState( - actor: AnyActorInstance, - conn: AnyConn, - state: State, - ): ConnReadyState | undefined; -} - -// MARK: WebSocket -const WEBSOCKET_DRIVER: ConnDriver = { - sendMessage: ( - actor: AnyActorInstance, - conn: AnyConn, - state: ConnDriverWebSocketState, - message: CachedSerializer, - ) => { - if (state.websocket.readyState !== ConnReadyState.OPEN) { - actor.rLog.warn({ - msg: "attempting to send message to closed websocket, this is likely a bug in RivetKit", - connId: conn.id, - wsReadyState: state.websocket.readyState, - }); - return; - } - - const serialized = message.serialize(state.encoding); - - actor.rLog.debug({ - msg: "sending websocket message", - encoding: state.encoding, - dataType: typeof serialized, - isUint8Array: serialized instanceof Uint8Array, - isArrayBuffer: serialized instanceof ArrayBuffer, - dataLength: - (serialized as any).byteLength || (serialized as any).length, - }); - - // Convert Uint8Array to ArrayBuffer for proper transmission - if (serialized instanceof Uint8Array) { - const buffer = serialized.buffer.slice( - serialized.byteOffset, - serialized.byteOffset + serialized.byteLength, - ); - // Handle SharedArrayBuffer case - if (buffer instanceof SharedArrayBuffer) { - const arrayBuffer = new ArrayBuffer(buffer.byteLength); - new Uint8Array(arrayBuffer).set(new Uint8Array(buffer)); - actor.rLog.debug({ - msg: "converted SharedArrayBuffer to ArrayBuffer", - byteLength: arrayBuffer.byteLength, - }); - state.websocket.send(arrayBuffer); - } else { - actor.rLog.debug({ - msg: "sending ArrayBuffer", - byteLength: buffer.byteLength, - }); - state.websocket.send(buffer); - } - } else { - actor.rLog.debug({ - msg: "sending string data", - length: (serialized as string).length, - }); - state.websocket.send(serialized); - } - }, - - disconnect: async ( - _actor: AnyActorInstance, - _conn: AnyConn, - state: ConnDriverWebSocketState, - reason?: string, - ) => { - // Close socket - state.websocket.close(1000, reason); - - // Create promise to wait for socket to close gracefully - await state.closePromise.promise; - }, - - getConnectionReadyState: ( - _actor: AnyActorInstance, - _conn: AnyConn, - state: ConnDriverWebSocketState, - ): ConnReadyState | undefined => { - return state.websocket.readyState; - }, -}; - -// MARK: HTTP -const HTTP_DRIVER: ConnDriver = { - getConnectionReadyState(_actor, _conn) { - // TODO: This might not be the correct logic - return ConnReadyState.OPEN; - }, - disconnect: async () => { - // Noop - // TODO: Abort the request - }, -}; - -/** List of all connection drivers. */ -export const CONN_DRIVERS: Record> = { - [ConnDriverKind.WEBSOCKET]: WEBSOCKET_DRIVER, - [ConnDriverKind.HTTP]: HTTP_DRIVER, -}; - -export function getConnDriverKindFromState( - state: ConnDriverState, -): ConnDriverKind { - if (ConnDriverKind.WEBSOCKET in state) return ConnDriverKind.WEBSOCKET; - else if (ConnDriverKind.HTTP in state) return ConnDriverKind.HTTP; - else assertUnreachable(state); -} diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn-socket.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn-socket.ts deleted file mode 100644 index 79f0745392..0000000000 --- a/rivetkit-typescript/packages/rivetkit/src/actor/conn-socket.ts +++ /dev/null @@ -1,9 +0,0 @@ -import type { ConnDriverState } from "./conn-drivers"; - -export interface ConnSocket { - /** This is the request ID provided by the given framework. If not provided this is a random UUID. */ - requestId: string; - requestIdBuf?: ArrayBuffer; - hibernatable: boolean; - driverState: ConnDriverState; -} diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn/driver.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn/driver.ts new file mode 100644 index 0000000000..b53080e8c0 --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn/driver.ts @@ -0,0 +1,45 @@ +import type { AnyConn } from "@/actor/conn/mod"; +import type { AnyActorInstance } from "@/actor/instance/mod"; +import type { CachedSerializer } from "@/actor/protocol/serde"; +import type * as protocol from "@/schemas/client-protocol/mod"; + +export enum DriverReadyState { + UNKNOWN = -1, + CONNECTING = 0, + OPEN = 1, + CLOSING = 2, + CLOSED = 3, +} + +export interface ConnDriver { + requestId: string; + requestIdBuf: ArrayBuffer | undefined; + hibernatable: boolean; + + sendMessage?( + actor: AnyActorInstance, + conn: AnyConn, + message: CachedSerializer, + ): void; + + /** + * This returns a promise since we commonly disconnect at the end of a program, and not waiting will cause the socket to not close cleanly. + */ + disconnect( + actor: AnyActorInstance, + conn: AnyConn, + reason?: string, + ): Promise; + + /** Terminates the connection without graceful handling. */ + terminate?(actor: AnyActorInstance, conn: AnyConn): void; + + /** + * Returns the ready state of the connection. + * This is used to determine if the connection is ready to send messages, or if the connection is stale. + */ + getConnectionReadyState( + actor: AnyActorInstance, + conn: AnyConn, + ): DriverReadyState | undefined; +} diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/http.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/http.ts new file mode 100644 index 0000000000..a332f2bbea --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/http.ts @@ -0,0 +1,19 @@ +import { type ConnDriver, DriverReadyState } from "../driver"; + +export type ConnHttpState = Record; + +export function createHttpSocket(): ConnDriver { + return { + requestId: crypto.randomUUID(), + requestIdBuf: undefined, + hibernatable: false, + getConnectionReadyState(_actor, _conn) { + // TODO: This might not be the correct logic + return DriverReadyState.OPEN; + }, + disconnect: async () => { + // Noop + // TODO: Configure with abort signals to abort the request + }, + }; +} diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/websocket.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/websocket.ts new file mode 100644 index 0000000000..ea4e1cc977 --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/websocket.ts @@ -0,0 +1,103 @@ +import type { WSContext } from "hono/ws"; +import type { AnyConn } from "@/actor/conn/mod"; +import type { AnyActorInstance } from "@/actor/instance/mod"; +import type { CachedSerializer, Encoding } from "@/actor/protocol/serde"; +import type * as protocol from "@/schemas/client-protocol/mod"; +import { type ConnDriver, DriverReadyState } from "../driver"; + +export type ConnDriverWebSocketState = {}; + +export function createWebSocketSocket( + requestId: string, + requestIdBuf: ArrayBuffer | undefined, + hibernatable: boolean, + encoding: Encoding, + websocket: WSContext, + closePromise: Promise, +): ConnDriver { + return { + requestId, + requestIdBuf, + hibernatable, + sendMessage: ( + actor: AnyActorInstance, + conn: AnyConn, + message: CachedSerializer, + ) => { + if (websocket.readyState !== DriverReadyState.OPEN) { + actor.rLog.warn({ + msg: "attempting to send message to closed websocket, this is likely a bug in RivetKit", + connId: conn.id, + wsReadyState: websocket.readyState, + }); + return; + } + + const serialized = message.serialize(encoding); + + actor.rLog.debug({ + msg: "sending websocket message", + encoding: encoding, + dataType: typeof serialized, + isUint8Array: serialized instanceof Uint8Array, + isArrayBuffer: serialized instanceof ArrayBuffer, + dataLength: + (serialized as any).byteLength || + (serialized as any).length, + }); + + // Convert Uint8Array to ArrayBuffer for proper transmission + if (serialized instanceof Uint8Array) { + const buffer = serialized.buffer.slice( + serialized.byteOffset, + serialized.byteOffset + serialized.byteLength, + ); + // Handle SharedArrayBuffer case + if (buffer instanceof SharedArrayBuffer) { + const arrayBuffer = new ArrayBuffer(buffer.byteLength); + new Uint8Array(arrayBuffer).set(new Uint8Array(buffer)); + actor.rLog.debug({ + msg: "converted SharedArrayBuffer to ArrayBuffer", + byteLength: arrayBuffer.byteLength, + }); + websocket.send(arrayBuffer); + } else { + actor.rLog.debug({ + msg: "sending ArrayBuffer", + byteLength: buffer.byteLength, + }); + websocket.send(buffer); + } + } else { + actor.rLog.debug({ + msg: "sending string data", + length: (serialized as string).length, + }); + websocket.send(serialized); + } + }, + + disconnect: async ( + _actor: AnyActorInstance, + _conn: AnyConn, + reason?: string, + ) => { + // Close socket + websocket.close(1000, reason); + + // Create promise to wait for socket to close gracefully + await closePromise; + }, + + terminate: () => { + (websocket as any).terminate(); + }, + + getConnectionReadyState: ( + _actor: AnyActorInstance, + _conn: AnyConn, + ): DriverReadyState | undefined => { + return websocket.readyState; + }, + }; +} diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts similarity index 73% rename from rivetkit-typescript/packages/rivetkit/src/actor/conn.ts rename to rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts index 3d4170cee9..c4fdf59eee 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/conn.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts @@ -4,17 +4,15 @@ import { isCborSerializable } from "@/common/utils"; import type * as protocol from "@/schemas/client-protocol/mod"; import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned"; import { arrayBuffersEqual, bufferToArrayBuffer } from "@/utils"; +import type { AnyDatabaseProvider } from "../database"; +import * as errors from "../errors"; import { - CONN_DRIVERS, - type ConnDriverState, - getConnDriverKindFromState, -} from "./conn-drivers"; -import type { ConnSocket } from "./conn-socket"; -import type { AnyDatabaseProvider } from "./database"; -import * as errors from "./errors"; -import { type ActorInstance, PERSIST_SYMBOL } from "./instance"; -import type { PersistedConn } from "./persisted"; -import { CachedSerializer } from "./protocol/serde"; + ACTOR_INSTANCE_PERSIST_SYMBOL, + type ActorInstance, +} from "../instance/mod"; +import type { PersistedConn } from "../instance/persisted"; +import { CachedSerializer } from "../protocol/serde"; +import type { ConnDriver } from "./driver"; export function generateConnRequestId(): string { return crypto.randomUUID(); @@ -24,6 +22,9 @@ export type ConnId = string; export type AnyConn = Conn; +export const CONN_PERSIST_SYMBOL = Symbol("persist"); +export const CONN_DRIVER_SYMBOL = Symbol("driver"); + /** * Represents a client connection to a actor. * @@ -46,7 +47,7 @@ export class Conn { * This will only be persisted if using hibernatable WebSockets. If not, * this is just used to hole state. */ - __persist!: PersistedConn; + [CONN_PERSIST_SYMBOL]!: PersistedConn; /** Raw persist object without the proxy wrapper */ #persistRaw: PersistedConn; @@ -54,22 +55,16 @@ export class Conn { /** Track if this connection's state has changed */ #changed = false; - get __driverState(): ConnDriverState | undefined { - return this.__socket?.driverState; - } - /** - * Socket connected to this connection. - * * If undefined, then nothing is connected to this. */ - __socket?: ConnSocket; + [CONN_DRIVER_SYMBOL]?: ConnDriver; public get params(): CP { - return this.__persist.params; + return this[CONN_PERSIST_SYMBOL].params; } - public get __stateEnabled() { + public get stateEnabled() { return this.#actor.connStateEnabled; } @@ -80,8 +75,9 @@ export class Conn { */ public get state(): CS { this.#validateStateEnabled(); - if (!this.__persist.state) throw new Error("state should exists"); - return this.__persist.state; + if (!this[CONN_PERSIST_SYMBOL].state) + throw new Error("state should exists"); + return this[CONN_PERSIST_SYMBOL].state; } /** @@ -91,14 +87,14 @@ export class Conn { */ public set state(value: CS) { this.#validateStateEnabled(); - this.__persist.state = value; + this[CONN_PERSIST_SYMBOL].state = value; } /** * Unique identifier for the connection. */ public get id(): ConnId { - return this.__persist.connId; + return this[CONN_PERSIST_SYMBOL].connId; } /** @@ -107,14 +103,16 @@ export class Conn { * If the underlying connection can hibernate. */ public get isHibernatable(): boolean { - if (!this.__persist.hibernatableRequestId) { + if (!this[CONN_PERSIST_SYMBOL].hibernatableRequestId) { return false; } return ( - this.#actor[PERSIST_SYMBOL].hibernatableConns.findIndex((conn) => + (this.#actor as any)[ + ACTOR_INSTANCE_PERSIST_SYMBOL + ].hibernatableConns.findIndex((conn: any) => arrayBuffersEqual( conn.hibernatableRequestId, - this.__persist.hibernatableRequestId!, + this[CONN_PERSIST_SYMBOL].hibernatableRequestId!, ), ) > -1 ); @@ -124,7 +122,7 @@ export class Conn { * Timestamp of the last time the connection was seen, i.e. the last time the connection was active and checked for liveness. */ public get lastSeen(): number { - return this.__persist.lastSeen; + return this[CONN_PERSIST_SYMBOL].lastSeen; } /** @@ -149,12 +147,12 @@ export class Conn { #setupPersistProxy(persist: PersistedConn) { // If this can't be proxied, return raw value if (persist === null || typeof persist !== "object") { - this.__persist = persist; + this[CONN_PERSIST_SYMBOL] = persist; return; } // Listen for changes to the object - this.__persist = onChange( + this[CONN_PERSIST_SYMBOL] = onChange( persist, ( path: string, @@ -188,7 +186,7 @@ export class Conn { }); // Notify actor that this connection has changed - this.#actor.__markConnChanged(this); + this.#actor.markConnChanged(this); }, { ignoreDetached: true }, ); @@ -216,29 +214,16 @@ export class Conn { } #validateStateEnabled() { - if (!this.__stateEnabled) { + if (!this.stateEnabled) { throw new errors.ConnStateNotEnabled(); } } - /** - * Sends a WebSocket message to the client. - * - * @param message - The message to send. - * - * @protected - */ - public _sendMessage(message: CachedSerializer) { - if (this.__driverState) { - const driverKind = getConnDriverKindFromState(this.__driverState); - const driver = CONN_DRIVERS[driverKind]; + public sendMessage(message: CachedSerializer) { + if (this[CONN_DRIVER_SYMBOL]) { + const driver = this[CONN_DRIVER_SYMBOL]; if (driver.sendMessage) { - driver.sendMessage( - this.#actor, - this, - (this.__driverState as any)[driverKind], - message, - ); + driver.sendMessage(this.#actor, this, message); } else { this.#actor.rLog.debug({ msg: "conn driver does not support sending messages", @@ -267,7 +252,7 @@ export class Conn { args, connId: this.id, }); - this._sendMessage( + this.sendMessage( new CachedSerializer( { body: { @@ -289,16 +274,10 @@ export class Conn { * @param reason - The reason for disconnection. */ public async disconnect(reason?: string) { - if (this.__socket && this.__driverState) { - const driverKind = getConnDriverKindFromState(this.__driverState); - const driver = CONN_DRIVERS[driverKind]; + if (this[CONN_DRIVER_SYMBOL]) { + const driver = this[CONN_DRIVER_SYMBOL]; if (driver.disconnect) { - driver.disconnect( - this.#actor, - this, - (this.__driverState as any)[driverKind], - reason, - ); + driver.disconnect(this.#actor, this, reason); } else { this.#actor.rLog.debug({ msg: "no disconnect handler for conn driver", @@ -306,7 +285,7 @@ export class Conn { }); } - this.#actor.__connDisconnected(this, true, this.__socket.requestId); + this.#actor.connDisconnected(this, true); } else { this.#actor.rLog.warn({ msg: "missing connection driver state for disconnect", @@ -314,6 +293,6 @@ export class Conn { }); } - this.__socket = undefined; + this[CONN_DRIVER_SYMBOL] = undefined; } } diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/action.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/action.ts similarity index 93% rename from rivetkit-typescript/packages/rivetkit/src/actor/action.ts rename to rivetkit-typescript/packages/rivetkit/src/actor/contexts/action.ts index 8c2c21c305..a8953d883e 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/action.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/action.ts @@ -2,11 +2,11 @@ import type { ActorKey } from "@/actor/mod"; import type { Client } from "@/client/client"; import type { Logger } from "@/common/log"; import type { Registry } from "@/registry/mod"; -import type { Conn, ConnId } from "./conn"; -import type { ActorContext } from "./context"; -import type { AnyDatabaseProvider, InferDatabaseClient } from "./database"; -import type { SaveStateOptions } from "./instance"; -import type { Schedule } from "./schedule"; +import type { Conn, ConnId } from "../conn/mod"; +import type { AnyDatabaseProvider, InferDatabaseClient } from "../database"; +import type { SaveStateOptions } from "../instance/state-manager"; +import type { Schedule } from "../schedule"; +import type { ActorContext } from "./actor"; /** * Context for a remote procedure call. diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/context.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/actor.ts similarity index 90% rename from rivetkit-typescript/packages/rivetkit/src/actor/context.ts rename to rivetkit-typescript/packages/rivetkit/src/actor/contexts/actor.ts index f6e8f2cacc..0096580854 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/context.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/actor.ts @@ -2,10 +2,10 @@ import type { ActorKey } from "@/actor/mod"; import type { Client } from "@/client/client"; import type { Logger } from "@/common/log"; import type { Registry } from "@/registry/mod"; -import type { Conn, ConnId } from "./conn"; -import type { AnyDatabaseProvider, InferDatabaseClient } from "./database"; -import type { ActorInstance, SaveStateOptions } from "./instance"; -import type { Schedule } from "./schedule"; +import type { Conn, ConnId } from "../conn/mod"; +import type { AnyDatabaseProvider, InferDatabaseClient } from "../database"; +import type { ActorInstance, SaveStateOptions } from "../instance/mod"; +import type { Schedule } from "../schedule"; /** * ActorContext class that provides access to actor methods and state @@ -60,7 +60,7 @@ export class ActorContext< * @param args - The arguments to send with the event. */ broadcast>(name: string, ...args: Args): void { - this.#actor._broadcast(name, ...args); + this.#actor.broadcast(name, ...args); return; } @@ -145,7 +145,7 @@ export class ActorContext< * Prevents the actor from sleeping until promise is complete. */ waitUntil(promise: Promise): void { - this.#actor._waitUntil(promise); + this.#actor.waitUntil(promise); } /** @@ -163,6 +163,6 @@ export class ActorContext< * @experimental */ sleep() { - this.#actor._startSleep(); + this.#actor.startSleep(); } } diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/definition.ts b/rivetkit-typescript/packages/rivetkit/src/actor/definition.ts index 7f54f376a4..2111791e97 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/definition.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/definition.ts @@ -1,9 +1,9 @@ import type { RegistryConfig } from "@/registry/config"; -import type { ActionContext } from "./action"; import type { Actions, ActorConfig } from "./config"; -import type { ActorContext } from "./context"; +import type { ActionContext } from "./contexts/action"; +import type { ActorContext } from "./contexts/actor"; import type { AnyDatabaseProvider } from "./database"; -import { ActorInstance } from "./instance"; +import { ActorInstance } from "./instance/mod"; export type AnyActorDefinition = ActorDefinition< any, diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/driver.ts b/rivetkit-typescript/packages/rivetkit/src/actor/driver.ts index bdea8d7abc..58ea192c7e 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/driver.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/driver.ts @@ -3,7 +3,7 @@ import type { AnyClient } from "@/client/client"; import type { ManagerDriver } from "@/manager/driver"; import type { RegistryConfig } from "@/registry/config"; import type { RunnerConfig } from "@/registry/run-config"; -import type { AnyActorInstance } from "./instance"; +import type { AnyActorInstance } from "./instance/mod"; export type ActorDriverBuilder = ( registryConfig: RegistryConfig, diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance.ts deleted file mode 100644 index 685a73f33e..0000000000 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance.ts +++ /dev/null @@ -1,2217 +0,0 @@ -import * as cbor from "cbor-x"; -import invariant from "invariant"; -import onChange from "on-change"; -import type { ActorKey } from "@/actor/mod"; -import type { Client } from "@/client/client"; -import { getBaseLogger, getIncludeTarget, type Logger } from "@/common/log"; -import { isCborSerializable, stringifyError } from "@/common/utils"; -import type { UniversalWebSocket } from "@/common/websocket-interface"; -import { ActorInspector } from "@/inspector/actor"; -import type { Registry } from "@/mod"; -import type * as persistSchema from "@/schemas/actor-persist/mod"; -import { ACTOR_VERSIONED } from "@/schemas/actor-persist/versioned"; -import type * as protocol from "@/schemas/client-protocol/mod"; -import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned"; -import { - arrayBuffersEqual, - bufferToArrayBuffer, - EXTRA_ERROR_LOG, - idToStr, - promiseWithResolvers, - SinglePromiseQueue, -} from "@/utils"; -import { ActionContext } from "./action"; -import type { ActorConfig, OnConnectOptions } from "./config"; -import { Conn, type ConnId, generateConnRequestId } from "./conn"; -import { - CONN_DRIVERS, - ConnDriverKind, - getConnDriverKindFromState, -} from "./conn-drivers"; -import type { ConnSocket } from "./conn-socket"; -import { ActorContext } from "./context"; -import type { AnyDatabaseProvider, InferDatabaseClient } from "./database"; -import type { ActorDriver } from "./driver"; -import * as errors from "./errors"; -import { serializeActorKey } from "./keys"; -import { KEYS, makeConnKey } from "./kv"; -import type { - PersistedActor, - PersistedConn, - PersistedHibernatableConn, - PersistedScheduleEvent, -} from "./persisted"; -import { processMessage } from "./protocol/old"; -import { CachedSerializer } from "./protocol/serde"; -import { Schedule } from "./schedule"; -import { DeadlineError, deadline, isConnStatePath, isStatePath } from "./utils"; - -export const PERSIST_SYMBOL = Symbol("persist"); - -/** - * Options for the `_saveState` method. - */ -export interface SaveStateOptions { - /** - * Forces the state to be saved immediately. This function will return when the state has saved successfully. - */ - immediate?: boolean; - /** Bypass ready check for stopping. */ - allowStoppingState?: boolean; -} - -/** Actor type alias with all `any` types. Used for `extends` in classes referencing this actor. */ -export type AnyActorInstance = ActorInstance< - // biome-ignore lint/suspicious/noExplicitAny: Needs to be used in `extends` - any, - // biome-ignore lint/suspicious/noExplicitAny: Needs to be used in `extends` - any, - // biome-ignore lint/suspicious/noExplicitAny: Needs to be used in `extends` - any, - // biome-ignore lint/suspicious/noExplicitAny: Needs to be used in `extends` - any, - // biome-ignore lint/suspicious/noExplicitAny: Needs to be used in `extends` - any, - // biome-ignore lint/suspicious/noExplicitAny: Needs to be used in `extends` - any ->; - -export type ExtractActorState = - A extends ActorInstance< - infer State, - // biome-ignore lint/suspicious/noExplicitAny: Must be used for `extends` - any, - // biome-ignore lint/suspicious/noExplicitAny: Must be used for `extends` - any, - // biome-ignore lint/suspicious/noExplicitAny: Must be used for `extends` - any, - // biome-ignore lint/suspicious/noExplicitAny: Must be used for `extends` - any, - // biome-ignore lint/suspicious/noExplicitAny: Must be used for `extends` - any - > - ? State - : never; - -export type ExtractActorConnParams = - A extends ActorInstance< - // biome-ignore lint/suspicious/noExplicitAny: Must be used for `extends` - any, - infer ConnParams, - // biome-ignore lint/suspicious/noExplicitAny: Must be used for `extends` - any, - // biome-ignore lint/suspicious/noExplicitAny: Must be used for `extends` - any, - // biome-ignore lint/suspicious/noExplicitAny: Must be used for `extends` - any, - // biome-ignore lint/suspicious/noExplicitAny: Must be used for `extends` - any - > - ? ConnParams - : never; - -export type ExtractActorConnState = - A extends ActorInstance< - // biome-ignore lint/suspicious/noExplicitAny: Must be used for `extends` - any, - // biome-ignore lint/suspicious/noExplicitAny: Must be used for `extends` - any, - infer ConnState, - // biome-ignore lint/suspicious/noExplicitAny: Must be used for `extends` - any, - // biome-ignore lint/suspicious/noExplicitAny: Must be used for `extends` - any, - // biome-ignore lint/suspicious/noExplicitAny: Must be used for `extends` - any - > - ? ConnState - : never; - -enum CanSleep { - Yes, - NotReady, - ActiveConns, - ActiveHonoHttpRequests, - ActiveRawWebSockets, -} - -export class ActorInstance { - // Shared actor context for this instance - actorContext: ActorContext; - - /** Actor log, intended for the user to call */ - #log!: Logger; - - get log(): Logger { - invariant(this.#log, "log not configured"); - return this.#log; - } - - /** Runtime log, intended for internal actor logs */ - #rLog!: Logger; - - get rLog(): Logger { - invariant(this.#rLog, "log not configured"); - return this.#rLog; - } - - #sleepCalled = false; - #stopCalled = false; - - get isStopping() { - return this.#stopCalled; - } - - #persistChanged = false; - #isInOnStateChange = false; - - /** - * The proxied state that notifies of changes automatically. - * - * Any data that should be stored indefinitely should be held within this object. - */ - #persist!: PersistedActor; - - get [PERSIST_SYMBOL](): PersistedActor { - return this.#persist; - } - - /** Raw state without the proxy wrapper */ - #persistRaw!: PersistedActor; - - #persistWriteQueue = new SinglePromiseQueue(); - #alarmWriteQueue = new SinglePromiseQueue(); - - #lastSaveTime = 0; - #pendingSaveTimeout?: NodeJS.Timeout; - - #vars?: V; - - #backgroundPromises: Promise[] = []; - - #abortController = new AbortController(); - - #config: ActorConfig; - #actorDriver!: ActorDriver; - #inlineClient!: Client>; - #actorId!: string; - - #name!: string; - - get name(): string { - return this.#name; - } - - #key!: ActorKey; - - get key(): ActorKey { - return this.#key; - } - - #region!: string; - - get region(): string { - return this.#region; - } - - #ready = false; - - #connections = new Map>(); - - get conns(): Map> { - return this.#connections; - } - - #subscriptionIndex = new Map>>(); - #changedConnections = new Set(); - - #sleepTimeout?: NodeJS.Timeout; - - /** - * Track active HTTP requests through Hono router so sleep logic can - * account for them. Does not include WebSockets. - **/ - #activeHonoHttpRequests = 0; - #activeRawWebSockets = new Set(); - - #schedule!: Schedule; - - get schedule(): Schedule { - return this.#schedule; - } - - #db!: InferDatabaseClient; - - /** - * Gets the database. - * @experimental - */ - get db(): InferDatabaseClient { - if (!this.#db) { - throw new errors.DatabaseNotEnabled(); - } - return this.#db; - } - - #inspector = new ActorInspector(() => { - return { - isDbEnabled: async () => { - return this.#db !== undefined; - }, - getDb: async () => { - return this.db; - }, - isStateEnabled: async () => { - return this.stateEnabled; - }, - getState: async () => { - this.#validateStateEnabled(); - - // Must return from `#persistRaw` in order to not return the `onchange` proxy - return this.#persistRaw.state as Record as unknown; - }, - getRpcs: async () => { - return Object.keys(this.#config.actions); - }, - getConnections: async () => { - return Array.from(this.#connections.entries()).map( - ([id, conn]) => ({ - id, - params: conn.params as any, - state: conn.__stateEnabled ? conn.state : undefined, - subscriptions: conn.subscriptions.size, - lastSeen: conn.lastSeen, - stateEnabled: conn.__stateEnabled, - isHibernatable: conn.isHibernatable, - hibernatableRequestId: conn.__persist - .hibernatableRequestId - ? idToStr(conn.__persist.hibernatableRequestId) - : undefined, - driver: conn.__driverState - ? getConnDriverKindFromState(conn.__driverState) - : undefined, - }), - ); - }, - setState: async (state: unknown) => { - this.#validateStateEnabled(); - - // Must set on `#persist` instead of `#persistRaw` in order to ensure that the `Proxy` is correctly configured - // - // We have to use `...` so `on-change` recognizes the changes to `state` (i.e. set #persistChanged` to true). This is because: - // 1. In `getState`, we returned the value from `persistRaw`, which does not have the Proxy to monitor state changes - // 2. If we were to assign `state` to `#persist.s`, `on-change` would assume nothing changed since `state` is still === `#persist.s` since we returned a reference in `getState` - this.#persist.state = { ...(state as S) }; - await this.saveState({ immediate: true }); - }, - executeAction: async (name, params) => { - const requestId = generateConnRequestId(); - const conn = await this.createConn( - { - requestId: requestId, - hibernatable: false, - driverState: { [ConnDriverKind.HTTP]: {} }, - }, - undefined, - undefined, - ); - - try { - return await this.executeAction( - new ActionContext(this.actorContext, conn), - name, - params || [], - ); - } finally { - this.__connDisconnected(conn, true, requestId); - } - }, - }; - }); - - get id() { - return this.#actorId; - } - - get inlineClient(): Client> { - return this.#inlineClient; - } - - get inspector() { - return this.#inspector; - } - - get #sleepingSupported(): boolean { - return this.#actorDriver.startSleep !== undefined; - } - - /** - * This constructor should never be used directly. - * - * Constructed in {@link ActorInstance.start}. - * - * @private - */ - constructor(config: ActorConfig) { - this.#config = config; - this.actorContext = new ActorContext(this); - } - - // MARK: Initialization - async start( - actorDriver: ActorDriver, - inlineClient: Client>, - actorId: string, - name: string, - key: ActorKey, - region: string, - ) { - const logParams = { - actor: name, - key: serializeActorKey(key), - actorId, - }; - - const extraLogParams = actorDriver.getExtraActorLogParams?.(); - if (extraLogParams) Object.assign(logParams, extraLogParams); - - this.#log = getBaseLogger().child( - Object.assign( - getIncludeTarget() ? { target: "actor" } : {}, - logParams, - ), - ); - this.#rLog = getBaseLogger().child( - Object.assign( - getIncludeTarget() ? { target: "actor-runtime" } : {}, - logParams, - ), - ); - - this.#actorDriver = actorDriver; - this.#inlineClient = inlineClient; - this.#actorId = actorId; - this.#name = name; - this.#key = key; - this.#region = region; - this.#schedule = new Schedule(this); - - // Read initial state from KV storage - const [persistDataBuffer] = await this.#actorDriver.kvBatchGet( - this.#actorId, - [KEYS.PERSIST_DATA], - ); - invariant( - persistDataBuffer !== null, - "persist data has not been set, it should be set when initialized", - ); - const bareData = - ACTOR_VERSIONED.deserializeWithEmbeddedVersion(persistDataBuffer); - const persistData = this.#convertFromBarePersisted(bareData); - - if (persistData.hasInitialized) { - // List all connection keys - const connEntries = await this.#actorDriver.kvListPrefix( - this.#actorId, - KEYS.CONN_PREFIX, - ); - - // Decode connections - const connections: PersistedConn[] = []; - for (const [_key, value] of connEntries) { - try { - const conn = cbor.decode(value) as PersistedConn; - connections.push(conn); - } catch (error) { - this.#rLog.error({ - msg: "failed to decode connection", - error: stringifyError(error), - }); - } - } - - this.#rLog.info({ - msg: "actor restoring", - connections: connections.length, - hibernatableWebSockets: persistData.hibernatableConns.length, - }); - - // Set initial state - this.#initPersistProxy(persistData); - - // Create connection instances - for (const connPersist of connections) { - // Create connections - const conn = new Conn(this, connPersist); - this.#connections.set(conn.id, conn); - - // Register event subscriptions - for (const sub of connPersist.subscriptions) { - this.#addSubscription(sub.eventName, conn, true); - } - } - } else { - this.#rLog.info({ msg: "actor creating" }); - - // Initialize actor state - let stateData: unknown; - if (this.stateEnabled) { - this.#rLog.info({ msg: "actor state initializing" }); - - if ("createState" in this.#config) { - this.#config.createState; - - // Convert state to undefined since state is not defined yet here - stateData = await this.#config.createState( - this.actorContext as unknown as ActorContext< - undefined, - undefined, - undefined, - undefined, - undefined, - undefined - >, - persistData.input!, - ); - } else if ("state" in this.#config) { - stateData = structuredClone(this.#config.state); - } else { - throw new Error( - "Both 'createState' or 'state' were not defined", - ); - } - } else { - this.#rLog.debug({ msg: "state not enabled" }); - } - - // Save state and mark as initialized - persistData.state = stateData as S; - persistData.hasInitialized = true; - - // Update state - this.#rLog.debug({ msg: "writing state" }); - const bareData = this.#convertToBarePersisted(persistData); - await this.#actorDriver.kvBatchPut(this.#actorId, [ - [ - KEYS.PERSIST_DATA, - ACTOR_VERSIONED.serializeWithEmbeddedVersion(bareData), - ], - ]); - - this.#initPersistProxy(persistData); - - // Notify creation - if (this.#config.onCreate) { - await this.#config.onCreate( - this.actorContext, - persistData.input!, - ); - } - } - - // TODO: Exit process if this errors - if (this.#varsEnabled) { - let vars: V | undefined; - if ("createVars" in this.#config) { - const dataOrPromise = this.#config.createVars( - this.actorContext as unknown as ActorContext< - undefined, - undefined, - undefined, - undefined, - undefined, - any - >, - this.#actorDriver.getContext(this.#actorId), - ); - if (dataOrPromise instanceof Promise) { - vars = await deadline( - dataOrPromise, - this.#config.options.createVarsTimeout, - ); - } else { - vars = dataOrPromise; - } - } else if ("vars" in this.#config) { - vars = structuredClone(this.#config.vars); - } else { - throw new Error( - "Could not variables from 'createVars' or 'vars'", - ); - } - this.#vars = vars; - } - - // TODO: Exit process if this errors - this.#rLog.info({ msg: "actor starting" }); - if (this.#config.onStart) { - const result = this.#config.onStart(this.actorContext); - if (result instanceof Promise) { - await result; - } - } - - // Setup Database - if ("db" in this.#config && this.#config.db) { - const client = await this.#config.db.createClient({ - getDatabase: () => actorDriver.getDatabase(this.#actorId), - }); - this.#rLog.info({ msg: "database migration starting" }); - await this.#config.db.onMigrate?.(client); - this.#rLog.info({ msg: "database migration complete" }); - this.#db = client; - } - - // Set alarm for next scheduled event if any exist after finishing initiation sequence - if (this.#persist.scheduledEvents.length > 0) { - await this.#queueSetAlarm( - this.#persist.scheduledEvents[0].timestamp, - ); - } - - this.#rLog.info({ msg: "actor ready" }); - this.#ready = true; - - // Must be called after setting `#ready` or else it will not schedule sleep - this.#resetSleepTimer(); - - // Trigger any pending alarms - await this._onAlarm(); - } - - #assertReady(allowStoppingState: boolean = false) { - if (!this.#ready) throw new errors.InternalError("Actor not ready"); - if (!allowStoppingState && this.#stopCalled) - throw new errors.InternalError("Actor is stopping"); - } - - /** - * Check if the actor is ready to handle requests. - */ - isReady(): boolean { - return this.#ready; - } - - // MARK: Stop - /** - * For the engine: - * 1. Engine runner receives CommandStopActor - * 2. Engine runner calls _onStop and waits for it to finish - * 3. Engine runner publishes EventActorStateUpdate with ActorStateSTop - */ - async _onStop() { - if (this.#stopCalled) { - this.#rLog.warn({ msg: "already stopping actor" }); - return; - } - this.#stopCalled = true; - - this.#rLog.info({ msg: "actor stopping" }); - - if (this.#sleepTimeout) { - clearTimeout(this.#sleepTimeout); - this.#sleepTimeout = undefined; - } - - // Abort any listeners waiting for shutdown - try { - this.#abortController.abort(); - } catch {} - - // Call onStop lifecycle hook if defined - if (this.#config.onStop) { - try { - this.#rLog.debug({ msg: "calling onStop" }); - const result = this.#config.onStop(this.actorContext); - if (result instanceof Promise) { - await deadline(result, this.#config.options.onStopTimeout); - } - this.#rLog.debug({ msg: "onStop completed" }); - } catch (error) { - if (error instanceof DeadlineError) { - this.#rLog.error({ msg: "onStop timed out" }); - } else { - this.#rLog.error({ - msg: "error in onStop", - error: stringifyError(error), - }); - } - } - } - - const promises: Promise[] = []; - - // Disconnect existing non-hibernatable connections - for (const connection of this.#connections.values()) { - if (!connection.isHibernatable) { - this.#rLog.debug({ - msg: "disconnecting non-hibernatable connection on actor stop", - connId: connection.id, - }); - promises.push(connection.disconnect()); - } - - // TODO: Figure out how to abort HTTP requests on shutdown. This - // might already be handled by the engine runner tunnel shutdown. - } - - // Wait for any background tasks to finish, with timeout - await this.#waitBackgroundPromises( - this.#config.options.waitUntilTimeout, - ); - - // Clear timeouts - if (this.#pendingSaveTimeout) clearTimeout(this.#pendingSaveTimeout); - - // Write state - await this.saveState({ immediate: true, allowStoppingState: true }); - - // Await all `close` event listeners with 1.5 second timeout - const res = Promise.race([ - Promise.all(promises).then(() => false), - new Promise((res) => - globalThis.setTimeout(() => res(true), 1500), - ), - ]); - - if (await res) { - this.#rLog.warn({ - msg: "timed out waiting for connections to close, shutting down anyway", - }); - } - - // Wait for queues to finish - if (this.#persistWriteQueue.runningDrainLoop) - await this.#persistWriteQueue.runningDrainLoop; - if (this.#alarmWriteQueue.runningDrainLoop) - await this.#alarmWriteQueue.runningDrainLoop; - } - - /** Abort signal that fires when the actor is stopping. */ - get abortSignal(): AbortSignal { - return this.#abortController.signal; - } - - // MARK: Sleep - /** - * Reset timer from the last actor interaction that allows it to be put to sleep. - * - * This should be called any time a sleep-related event happens: - * - Connection opens (will clear timer) - * - Connection closes (will schedule timer if there are no open connections) - * - Alarm triggers (will reset timer) - * - * We don't need to call this on events like individual action calls, since there will always be a connection open for these. - **/ - #resetSleepTimer() { - if (this.#config.options.noSleep || !this.#sleepingSupported) return; - - // Don't sleep if already stopping - if (this.#stopCalled) return; - - const canSleep = this.#canSleep(); - - this.#rLog.debug({ - msg: "resetting sleep timer", - canSleep: CanSleep[canSleep], - existingTimeout: !!this.#sleepTimeout, - timeout: this.#config.options.sleepTimeout, - }); - - if (this.#sleepTimeout) { - clearTimeout(this.#sleepTimeout); - this.#sleepTimeout = undefined; - } - - // Don't set a new timer if already sleeping - if (this.#sleepCalled) return; - - if (canSleep === CanSleep.Yes) { - this.#sleepTimeout = setTimeout(() => { - this._startSleep(); - }, this.#config.options.sleepTimeout); - } - } - - /** If this actor can be put in a sleeping state. */ - #canSleep(): CanSleep { - if (!this.#ready) return CanSleep.NotReady; - - // Do not sleep if Hono HTTP requests are in-flight - if (this.#activeHonoHttpRequests > 0) - return CanSleep.ActiveHonoHttpRequests; - - // TODO: When WS hibernation is ready, update this to only count non-hibernatable websockets - // Do not sleep if there are raw websockets open - if (this.#activeRawWebSockets.size > 0) - return CanSleep.ActiveRawWebSockets; - - // Check for active conns. This will also cover active actions, since all actions have a connection. - for (const conn of this.#connections.values()) { - // TODO: Enable this when hibernation is implemented. We're waiting on support for Guard to not auto-wake the actor if it sleeps. - // if (!conn.isHibernatable) - // return false; - - // if (!conn.isHibernatable) return CanSleep.ActiveConns; - return CanSleep.ActiveConns; - } - - return CanSleep.Yes; - } - - /** - * Puts an actor to sleep. This should just start the sleep sequence, most shutdown logic should be in _stop (which is called by the ActorDriver when sleeping). - * - * For the engine, this will: - * 1. Publish EventActorIntent with ActorIntentSleep (via driver.startSleep) - * 2. Engine runner will wait for CommandStopActor - * 3. Engine runner will call _onStop and wait for it to finish - * 4. Engine runner will publish EventActorStateUpdate with ActorStateSTop - **/ - _startSleep() { - if (this.#stopCalled) { - this.#rLog.debug({ - msg: "cannot call _startSleep if actor already stopping", - }); - return; - } - - // IMPORTANT: #sleepCalled should have no effect on the actor's - // behavior aside from preventing calling _startSleep twice. Wait for - // `_onStop` before putting in a stopping state. - if (this.#sleepCalled) { - this.#rLog.warn({ - msg: "cannot call _startSleep twice, actor already sleeping", - }); - return; - } - this.#sleepCalled = true; - - // NOTE: Publishes ActorIntentSleep - const sleep = this.#actorDriver.startSleep?.bind( - this.#actorDriver, - this.#actorId, - ); - invariant(this.#sleepingSupported, "sleeping not supported"); - invariant(sleep, "no sleep on driver"); - - this.#rLog.info({ msg: "actor sleeping" }); - - // Schedule sleep to happen on the next tick. This allows for any action that calls _sleep to complete. - setImmediate(() => { - // The actor driver should call stop when ready to stop - // - // This will call _stop once Pegboard responds with the new status - sleep(); - }); - } - - /** - * Called by router middleware when an HTTP request begins. - */ - __beginHonoHttpRequest() { - this.#activeHonoHttpRequests++; - this.#resetSleepTimer(); - } - - /** - * Called by router middleware when an HTTP request ends. - */ - __endHonoHttpRequest() { - this.#activeHonoHttpRequests--; - if (this.#activeHonoHttpRequests < 0) { - this.#activeHonoHttpRequests = 0; - this.#rLog.warn({ - msg: "active hono requests went below 0, this is a RivetKit bug", - ...EXTRA_ERROR_LOG, - }); - } - this.#resetSleepTimer(); - } - - // MARK: State - /** - * Gets the current state. - * - * Changing properties of this value will automatically be persisted. - */ - get state(): S { - this.#validateStateEnabled(); - return this.#persist.state; - } - - /** - * Sets the current state. - * - * This property will automatically be persisted. - */ - set state(value: S) { - this.#validateStateEnabled(); - this.#persist.state = value; - } - - get stateEnabled() { - return "createState" in this.#config || "state" in this.#config; - } - - #validateStateEnabled() { - if (!this.stateEnabled) { - throw new errors.StateNotEnabled(); - } - } - - get connStateEnabled() { - return "createConnState" in this.#config || "connState" in this.#config; - } - - get vars(): V { - this.#validateVarsEnabled(); - invariant(this.#vars !== undefined, "vars not enabled"); - return this.#vars; - } - - get #varsEnabled() { - return "createVars" in this.#config || "vars" in this.#config; - } - - #validateVarsEnabled() { - if (!this.#varsEnabled) { - throw new errors.VarsNotEnabled(); - } - } - - /** - * Forces the state to get saved. - * - * This is helpful if running a long task that may fail later or when - * running a background job that updates the state. - * - * @param opts - Options for saving the state. - */ - async saveState(opts: SaveStateOptions) { - this.#assertReady(opts.allowStoppingState); - - this.#rLog.debug({ - msg: "saveState called", - persistChanged: this.#persistChanged, - allowStoppingState: opts.allowStoppingState, - immediate: opts.immediate, - }); - - if (this.#persistChanged) { - if (opts.immediate) { - // Save immediately - await this.#savePersistInner(); - } else { - // Create callback - if (!this.#onPersistSavedPromise) { - this.#onPersistSavedPromise = promiseWithResolvers(); - } - - // Save state throttled - this.#savePersistThrottled(); - - // Wait for save - await this.#onPersistSavedPromise.promise; - } - } - } - - /** Promise used to wait for a save to complete. This is required since you cannot await `#saveStateThrottled`. */ - #onPersistSavedPromise?: ReturnType>; - - /** Throttled save state method. Used to write to KV at a reasonable cadence. */ - #savePersistThrottled() { - const now = Date.now(); - const timeSinceLastSave = now - this.#lastSaveTime; - const saveInterval = this.#config.options.stateSaveInterval; - - // If we're within the throttle window and not already scheduled, schedule the next save. - if (timeSinceLastSave < saveInterval) { - if (this.#pendingSaveTimeout === undefined) { - this.#pendingSaveTimeout = setTimeout(() => { - this.#pendingSaveTimeout = undefined; - this.#savePersistInner(); - }, saveInterval - timeSinceLastSave); - } - } else { - // If we're outside the throttle window, save immediately - this.#savePersistInner(); - } - } - - /** Saves the state to KV. You probably want to use #saveStateThrottled instead except for a few edge cases. */ - async #savePersistInner() { - try { - this.#lastSaveTime = Date.now(); - - const hasChanges = - this.#persistChanged || this.#changedConnections.size > 0; - - if (hasChanges) { - const finished = this.#persistWriteQueue.enqueue(async () => { - this.#rLog.debug({ - msg: "saving persist", - actorChanged: this.#persistChanged, - connectionsChanged: this.#changedConnections.size, - }); - - await this.#writePersistedData(); - - this.#rLog.debug({ msg: "persist saved" }); - }); - - await finished; - } - - this.#onPersistSavedPromise?.resolve(); - } catch (error) { - this.#rLog.error({ - msg: "error saving persist", - error: stringifyError(error), - }); - this.#onPersistSavedPromise?.reject(error); - throw error; - } - } - - async #writePersistedData() { - const entries: [Uint8Array, Uint8Array][] = []; - - // Save actor state if changed - if (this.#persistChanged) { - this.#persistChanged = false; - - // Prepare actor state - const bareData = this.#convertToBarePersisted(this.#persistRaw); - - // Key [1] for actor persist data - entries.push([ - KEYS.PERSIST_DATA, - ACTOR_VERSIONED.serializeWithEmbeddedVersion(bareData), - ]); - } - - // Save changed connections - if (this.#changedConnections.size > 0) { - for (const connId of this.#changedConnections) { - const conn = this.#connections.get(connId); - if (conn) { - const connData = cbor.encode(conn.persistRaw); - entries.push([makeConnKey(connId), connData]); - conn.markSaved(); - } - } - this.#changedConnections.clear(); - } - - // Write all entries in batch - if (entries.length > 0) { - await this.#actorDriver.kvBatchPut(this.#actorId, entries); - } - } - - /** - * Creates proxy for `#persist` that handles automatically flagging when state needs to be updated. - */ - #initPersistProxy(target: PersistedActor) { - // Set raw persist object - this.#persistRaw = target; - - // TODO: Allow disabling in production - // If this can't be proxied, return raw value - if (target === null || typeof target !== "object") { - let invalidPath = ""; - if ( - !isCborSerializable( - target, - (path) => { - invalidPath = path; - }, - "", - ) - ) { - throw new errors.InvalidStateType({ path: invalidPath }); - } - return target; - } - - // Unsubscribe from old state - if (this.#persist) { - onChange.unsubscribe(this.#persist); - } - - // Listen for changes to the object in order to automatically write state - this.#persist = onChange( - target, - // biome-ignore lint/suspicious/noExplicitAny: Don't know types in proxy - ( - path: string, - value: any, - _previousValue: any, - _applyData: any, - ) => { - const actorStatePath = isStatePath(path); - const connStatePath = isConnStatePath(path); - - // Validate CBOR serializability for state changes - if (actorStatePath || connStatePath) { - let invalidPath = ""; - if ( - !isCborSerializable( - value, - (invalidPathPart) => { - invalidPath = invalidPathPart; - }, - "", - ) - ) { - throw new errors.InvalidStateType({ - path: path + (invalidPath ? `.${invalidPath}` : ""), - }); - } - } - - this.#rLog.debug({ - msg: "onChange triggered, setting persistChanged=true", - path, - }); - this.#persistChanged = true; - - // Inform the inspector about state changes (only for state path) - if (actorStatePath) { - this.inspector.emitter.emit( - "stateUpdated", - this.#persist.state, - ); - } - - // Call onStateChange if it exists - // - // Skip if we're already inside onStateChange to prevent infinite recursion - if ( - actorStatePath && - this.#config.onStateChange && - this.#ready && - !this.#isInOnStateChange - ) { - try { - this.#isInOnStateChange = true; - this.#config.onStateChange( - this.actorContext, - this.#persistRaw.state, - ); - } catch (error) { - this.#rLog.error({ - msg: "error in `_onStateChange`", - error: stringifyError(error), - }); - } finally { - this.#isInOnStateChange = false; - } - } - - // State will be flushed at the end of the action - }, - { ignoreDetached: true }, - ); - } - - // MARK: Connections - __getConnForId(id: string): Conn | undefined { - return this.#connections.get(id); - } - - /** - * Mark a connection as changed so it will be persisted on next save - */ - __markConnChanged(conn: Conn) { - this.#changedConnections.add(conn.id); - this.#rLog.debug({ - msg: "marked connection as changed", - connId: conn.id, - totalChanged: this.#changedConnections.size, - }); - } - - /** - * Call when conn is disconnected. - * - * If a clean diconnect, will be removed immediately. - * - * If not a clean disconnect, will keep the connection alive for a given interval to wait for reconnect. - */ - __connDisconnected( - conn: Conn, - wasClean: boolean, - requestId: string, - ) { - // If socket ID is provided, check if it matches the current socket ID - // If it doesn't match, this is a stale disconnect event from an old socket - if ( - requestId && - conn.__socket && - requestId !== conn.__socket.requestId - ) { - this.#rLog.debug({ - msg: "ignoring stale disconnect event", - connId: conn.id, - eventRequestId: requestId, - currentRequestId: conn.__socket.requestId, - }); - return; - } - - if (wasClean) { - // Disconnected cleanly, remove the conn - - this.#removeConn(conn); - } else { - // Disconnected uncleanly, allow reconnection - - if (!conn.__driverState) { - this.rLog.warn("called conn disconnected without driver state"); - } - - // Update last seen so we know when to clean it up - conn.__persist.lastSeen = Date.now(); - - // Remove socket - conn.__socket = undefined; - - // Update sleep - this.#resetSleepTimer(); - } - } - - /** - * Removes a connection and cleans up its resources. - */ - #removeConn(conn: Conn) { - // Remove conn from KV - const key = makeConnKey(conn.id); - this.#actorDriver - .kvBatchDelete(this.#actorId, [key]) - .then(() => { - this.#rLog.debug({ - msg: "removed connection from KV", - connId: conn.id, - }); - }) - .catch((err) => { - this.#rLog.error({ - msg: "kvBatchDelete failed for conn", - err: stringifyError(err), - }); - }); - - // Remove from state and tracking - this.#connections.delete(conn.id); - this.#changedConnections.delete(conn.id); - this.#rLog.debug({ msg: "removed conn", connId: conn.id }); - - // Remove subscriptions - for (const eventName of [...conn.subscriptions.values()]) { - this.#removeSubscription(eventName, conn, true); - } - - this.inspector.emitter.emit("connectionUpdated"); - if (this.#config.onDisconnect) { - try { - const result = this.#config.onDisconnect( - this.actorContext, - conn, - ); - if (result instanceof Promise) { - // Handle promise but don't await it to prevent blocking - result.catch((error) => { - this.#rLog.error({ - msg: "error in `onDisconnect`", - error: stringifyError(error), - }); - }); - } - } catch (error) { - this.#rLog.error({ - msg: "error in `onDisconnect`", - error: stringifyError(error), - }); - } - } - - // Update sleep - this.#resetSleepTimer(); - } - - /** - * Called to create a new connection or reconnect an existing one. - */ - async createConn( - socket: ConnSocket, - // biome-ignore lint/suspicious/noExplicitAny: TypeScript bug with ExtractActorConnParams, - params: any, - request?: Request, - ): Promise> { - this.#assertReady(); - - // TODO: Remove this for ws hibernation v2 since we don't receive an open message for ws - // Check for hibernatable websocket reconnection - if (socket.requestIdBuf && socket.hibernatable) { - this.rLog.debug({ - msg: "checking for hibernatable websocket connection", - requestId: socket.requestId, - existingConnectionsCount: this.#connections.size, - }); - - // Find existing connection with matching hibernatableRequestId - const existingConn = Array.from(this.#connections.values()).find( - (conn) => - conn.__persist.hibernatableRequestId && - arrayBuffersEqual( - conn.__persist.hibernatableRequestId, - socket.requestIdBuf!, - ), - ); - - if (existingConn) { - this.rLog.debug({ - msg: "reconnecting hibernatable websocket connection", - connectionId: existingConn.id, - requestId: socket.requestId, - }); - - // If there's an existing driver state, clean it up without marking as clean disconnect - if (existingConn.__driverState) { - this.#rLog.warn({ - msg: "found existing driver state on hibernatable websocket", - connectionId: existingConn.id, - requestId: socket.requestId, - }); - const driverKind = getConnDriverKindFromState( - existingConn.__driverState, - ); - const driver = CONN_DRIVERS[driverKind]; - if (driver.disconnect) { - // Call driver disconnect to clean up directly. Don't use Conn.disconnect since that will remove the connection entirely. - driver.disconnect( - this, - existingConn, - (existingConn.__driverState as any)[driverKind], - "Reconnecting hibernatable websocket with new driver state", - ); - } - } - - // Update with new driver state - existingConn.__socket = socket; - existingConn.__persist.lastSeen = Date.now(); - - // Update sleep timer since connection is now active - this.#resetSleepTimer(); - - this.inspector.emitter.emit("connectionUpdated"); - - // We don't need to send a new init message since this is a - // hibernated request that has already been initialized - - return existingConn; - } else { - this.rLog.debug({ - msg: "no existing hibernatable connection found, creating new connection", - requestId: socket.requestId, - }); - } - } - - // Prepare connection state - let connState: CS | undefined; - - const onBeforeConnectOpts = { - request, - } satisfies OnConnectOptions; - - if (this.#config.onBeforeConnect) { - await this.#config.onBeforeConnect( - this.actorContext, - onBeforeConnectOpts, - params, - ); - } - - if (this.connStateEnabled) { - if ("createConnState" in this.#config) { - const dataOrPromise = this.#config.createConnState( - this.actorContext as unknown as ActorContext< - undefined, - undefined, - undefined, - undefined, - undefined, - undefined - >, - onBeforeConnectOpts, - params, - ); - if (dataOrPromise instanceof Promise) { - connState = await deadline( - dataOrPromise, - this.#config.options.createConnStateTimeout, - ); - } else { - connState = dataOrPromise; - } - } else if ("connState" in this.#config) { - connState = structuredClone(this.#config.connState); - } else { - throw new Error( - "Could not create connection state from 'createConnState' or 'connState'", - ); - } - } - - // Create connection - const persist: PersistedConn = { - connId: crypto.randomUUID(), - params: params, - state: connState as CS, - lastSeen: Date.now(), - subscriptions: [], - }; - - // Check if this connection is for a hibernatable websocket - if (socket.requestIdBuf) { - const isHibernatable = - this.#persist.hibernatableConns.findIndex((conn) => - arrayBuffersEqual( - conn.hibernatableRequestId, - socket.requestIdBuf!, - ), - ) !== -1; - - if (isHibernatable) { - persist.hibernatableRequestId = socket.requestIdBuf; - } - } - - const conn = new Conn(this, persist); - conn.__socket = socket; - this.#connections.set(conn.id, conn); - - // Update sleep - // - // Do this immediately after adding connection & before any async logic in order to avoid race conditions with sleep timeouts - this.#resetSleepTimer(); - - // Mark connection as changed for batch save - this.#changedConnections.add(conn.id); - - this.saveState({ immediate: true }); - - // Handle connection - if (this.#config.onConnect) { - try { - const result = this.#config.onConnect(this.actorContext, conn); - if (result instanceof Promise) { - deadline( - result, - this.#config.options.onConnectTimeout, - ).catch((error) => { - this.#rLog.error({ - msg: "error in `onConnect`, closing socket", - error, - }); - conn?.disconnect("`onConnect` failed"); - }); - } - } catch (error) { - this.#rLog.error({ - msg: "error in `onConnect`", - error: stringifyError(error), - }); - conn?.disconnect("`onConnect` failed"); - } - } - - this.inspector.emitter.emit("connectionUpdated"); - - // Send init message - conn._sendMessage( - new CachedSerializer( - { - body: { - tag: "Init", - val: { - actorId: this.id, - connectionId: conn.id, - }, - }, - }, - TO_CLIENT_VERSIONED, - ), - ); - - return conn; - } - - // MARK: Messages - async processMessage( - message: protocol.ToServer, - conn: Conn, - ) { - await processMessage(message, this, conn, { - onExecuteAction: async (ctx, name, args) => { - this.inspector.emitter.emit("eventFired", { - type: "action", - name, - args, - connId: conn.id, - }); - return await this.executeAction(ctx, name, args); - }, - onSubscribe: async (eventName, conn) => { - this.inspector.emitter.emit("eventFired", { - type: "subscribe", - eventName, - connId: conn.id, - }); - this.#addSubscription(eventName, conn, false); - }, - onUnsubscribe: async (eventName, conn) => { - this.inspector.emitter.emit("eventFired", { - type: "unsubscribe", - eventName, - connId: conn.id, - }); - this.#removeSubscription(eventName, conn, false); - }, - }); - } - - // MARK: Actions - /** - * Execute an action call from a client. - * - * This method handles: - * 1. Validating the action name - * 2. Executing the action function - * 3. Processing the result through onBeforeActionResponse (if configured) - * 4. Handling timeouts and errors - * 5. Saving state changes - * - * @param ctx The action context - * @param actionName The name of the action being called - * @param args The arguments passed to the action - * @returns The result of the action call - * @throws {ActionNotFound} If the action doesn't exist - * @throws {ActionTimedOut} If the action times out - * @internal - */ - async executeAction( - ctx: ActionContext, - actionName: string, - args: unknown[], - ): Promise { - invariant(this.#ready, "executing action before ready"); - - // Prevent calling private or reserved methods - if (!(actionName in this.#config.actions)) { - this.#rLog.warn({ msg: "action does not exist", actionName }); - throw new errors.ActionNotFound(actionName); - } - - // Check if the method exists on this object - const actionFunction = this.#config.actions[actionName]; - if (typeof actionFunction !== "function") { - this.#rLog.warn({ - msg: "action is not a function", - actionName: actionName, - type: typeof actionFunction, - }); - throw new errors.ActionNotFound(actionName); - } - - // TODO: pass abortable to the action to decide when to abort - // TODO: Manually call abortable for better error handling - // Call the function on this object with those arguments - try { - // Log when we start executing the action - this.#rLog.debug({ - msg: "executing action", - actionName: actionName, - args, - }); - - const outputOrPromise = actionFunction.call( - undefined, - ctx, - ...args, - ); - let output: unknown; - if (outputOrPromise instanceof Promise) { - // Log that we're waiting for an async action - this.#rLog.debug({ - msg: "awaiting async action", - actionName: actionName, - }); - - output = await deadline( - outputOrPromise, - this.#config.options.actionTimeout, - ); - - // Log that async action completed - this.#rLog.debug({ - msg: "async action completed", - actionName: actionName, - }); - } else { - output = outputOrPromise; - } - - // Process the output through onBeforeActionResponse if configured - if (this.#config.onBeforeActionResponse) { - try { - const processedOutput = this.#config.onBeforeActionResponse( - this.actorContext, - actionName, - args, - output, - ); - if (processedOutput instanceof Promise) { - this.#rLog.debug({ - msg: "awaiting onBeforeActionResponse", - actionName: actionName, - }); - output = await processedOutput; - this.#rLog.debug({ - msg: "onBeforeActionResponse completed", - actionName: actionName, - }); - } else { - output = processedOutput; - } - } catch (error) { - this.#rLog.error({ - msg: "error in `onBeforeActionResponse`", - error: stringifyError(error), - }); - } - } - - // Log the output before returning - this.#rLog.debug({ - msg: "action completed", - actionName: actionName, - outputType: typeof output, - isPromise: output instanceof Promise, - }); - - // This output *might* reference a part of the state (using onChange), but - // that's OK since this value always gets serialized and sent over the - // network. - return output; - } catch (error) { - if (error instanceof DeadlineError) { - throw new errors.ActionTimedOut(); - } - this.#rLog.error({ - msg: "action error", - actionName: actionName, - error: stringifyError(error), - }); - throw error; - } finally { - this.#savePersistThrottled(); - } - } - - /** - * Returns a list of action methods available on this actor. - */ - get actions(): string[] { - return Object.keys(this.#config.actions); - } - - /** - * Handles raw HTTP requests to the actor. - */ - async handleFetch( - request: Request, - opts: Record, - ): Promise { - this.#assertReady(); - - if (!this.#config.onFetch) { - throw new errors.FetchHandlerNotDefined(); - } - - try { - const response = await this.#config.onFetch( - this.actorContext, - request, - opts, - ); - if (!response) { - throw new errors.InvalidFetchResponse(); - } - return response; - } catch (error) { - this.#rLog.error({ - msg: "onFetch error", - error: stringifyError(error), - }); - throw error; - } finally { - this.#savePersistThrottled(); - } - } - - /** - * Handles raw WebSocket connections to the actor. - */ - async handleWebSocket( - websocket: UniversalWebSocket, - opts: { request: Request }, - ): Promise { - this.#assertReady(); - - if (!this.#config.onWebSocket) { - throw new errors.InternalError("onWebSocket handler not defined"); - } - - try { - // Set up state tracking to detect changes during WebSocket handling - const stateBeforeHandler = this.#persistChanged; - - // Track active websocket until it fully closes - this.#activeRawWebSockets.add(websocket); - this.#resetSleepTimer(); - - // Track hibernatable WebSockets - let rivetRequestId: ArrayBuffer | undefined; - let persistedHibernatableWebSocket: - | PersistedHibernatableConn - | undefined; - - const onSocketOpened = (event: any) => { - rivetRequestId = event?.rivetRequestId; - - // Find hibernatable WS - if (rivetRequestId) { - const rivetRequestIdLocal = rivetRequestId; - persistedHibernatableWebSocket = - this.#persist.hibernatableConns.find((conn) => - arrayBuffersEqual( - conn.hibernatableRequestId, - rivetRequestIdLocal, - ), - ); - - if (persistedHibernatableWebSocket) { - persistedHibernatableWebSocket.lastSeenTimestamp = - Date.now(); - } - } - - this.#rLog.debug({ - msg: "actor instance onSocketOpened", - rivetRequestId, - isHibernatable: !!persistedHibernatableWebSocket, - hibernationMsgIndex: - persistedHibernatableWebSocket?.msgIndex, - }); - }; - - const onSocketMessage = (event: any) => { - // Update state of hibernatable WS - if (persistedHibernatableWebSocket) { - persistedHibernatableWebSocket.lastSeenTimestamp = - Date.now(); - persistedHibernatableWebSocket.msgIndex = - event.rivetMessageIndex; - } - - this.#rLog.debug({ - msg: "actor instance onSocketMessage", - rivetRequestId, - isHibernatable: !!persistedHibernatableWebSocket, - hibernationMsgIndex: - persistedHibernatableWebSocket?.msgIndex, - }); - }; - - const onSocketClosed = (_event: any) => { - // Remove hibernatable WS - if (rivetRequestId) { - const rivetRequestIdLocal = rivetRequestId; - const wsIndex = this.#persist.hibernatableConns.findIndex( - (conn) => - arrayBuffersEqual( - conn.hibernatableRequestId, - rivetRequestIdLocal, - ), - ); - - const removed = this.#persist.hibernatableConns.splice( - wsIndex, - 1, - ); - if (removed.length > 0) { - this.#rLog.debug({ - msg: "removed hibernatable websocket", - rivetRequestId, - hibernationMsgIndex: - persistedHibernatableWebSocket?.msgIndex, - }); - } else { - this.#rLog.warn({ - msg: "could not find hibernatable websocket to remove", - rivetRequestId, - hibernationMsgIndex: - persistedHibernatableWebSocket?.msgIndex, - }); - } - } - - this.#rLog.debug({ - msg: "actor instance onSocketMessage", - rivetRequestId, - isHibernatable: !!persistedHibernatableWebSocket, - hibernatableWebSocketCount: - this.#persist.hibernatableConns.length, - }); - - // Remove listener and socket from tracking - try { - websocket.removeEventListener("open", onSocketOpened); - websocket.removeEventListener("message", onSocketMessage); - websocket.removeEventListener("close", onSocketClosed); - websocket.removeEventListener("error", onSocketClosed); - } catch {} - this.#activeRawWebSockets.delete(websocket); - this.#resetSleepTimer(); - }; - - try { - websocket.addEventListener("open", onSocketOpened); - websocket.addEventListener("message", onSocketMessage); - websocket.addEventListener("close", onSocketClosed); - websocket.addEventListener("error", onSocketClosed); - } catch {} - - // Handle WebSocket - await this.#config.onWebSocket(this.actorContext, websocket, opts); - - // If state changed during the handler, save it - if (this.#persistChanged && !stateBeforeHandler) { - await this.saveState({ immediate: true }); - } - } catch (error) { - this.#rLog.error({ - msg: "onWebSocket error", - error: stringifyError(error), - }); - throw error; - } finally { - this.#savePersistThrottled(); - } - } - - // MARK: Events - #addSubscription( - eventName: string, - connection: Conn, - fromPersist: boolean, - ) { - if (connection.subscriptions.has(eventName)) { - this.#rLog.debug({ - msg: "connection already has subscription", - eventName, - }); - return; - } - - // Persist subscriptions & save immediately - // - // Don't update persistence if already restoring from persistence - if (!fromPersist) { - connection.__persist.subscriptions.push({ eventName: eventName }); - - // Mark connection as changed - this.#changedConnections.add(connection.id); - - this.saveState({ immediate: true }); - } - - // Update subscriptions - connection.subscriptions.add(eventName); - - // Update subscription index - let subscribers = this.#subscriptionIndex.get(eventName); - if (!subscribers) { - subscribers = new Set(); - this.#subscriptionIndex.set(eventName, subscribers); - } - subscribers.add(connection); - } - - #removeSubscription( - eventName: string, - connection: Conn, - fromRemoveConn: boolean, - ) { - if (!connection.subscriptions.has(eventName)) { - this.#rLog.warn({ - msg: "connection does not have subscription", - eventName, - }); - return; - } - - // Persist subscriptions & save immediately - // - // Don't update the connection itself if the connection is already being removed - if (!fromRemoveConn) { - connection.subscriptions.delete(eventName); - - const subIdx = connection.__persist.subscriptions.findIndex( - (s) => s.eventName === eventName, - ); - if (subIdx !== -1) { - connection.__persist.subscriptions.splice(subIdx, 1); - } else { - this.#rLog.warn({ - msg: "subscription does not exist with name", - eventName, - }); - } - - // Mark connection as changed - this.#changedConnections.add(connection.id); - - this.saveState({ immediate: true }); - } - - // Update scriptions index - const subscribers = this.#subscriptionIndex.get(eventName); - if (subscribers) { - subscribers.delete(connection); - if (subscribers.size === 0) { - this.#subscriptionIndex.delete(eventName); - } - } - } - - /** - * Broadcasts an event to all connected clients. - * @param name - The name of the event. - * @param args - The arguments to send with the event. - */ - _broadcast>(name: string, ...args: Args) { - this.#assertReady(); - - this.inspector.emitter.emit("eventFired", { - type: "broadcast", - eventName: name, - args, - }); - - // Send to all connected clients - const subscriptions = this.#subscriptionIndex.get(name); - if (!subscriptions) return; - - const toClientSerializer = new CachedSerializer( - { - body: { - tag: "Event", - val: { - name, - args: bufferToArrayBuffer(cbor.encode(args)), - }, - }, - }, - TO_CLIENT_VERSIONED, - ); - - // Send message to clients - for (const connection of subscriptions) { - connection._sendMessage(toClientSerializer); - } - } - - // MARK: Alarms - async #scheduleEventInner(newEvent: PersistedScheduleEvent) { - this.actorContext.log.info({ msg: "scheduling event", ...newEvent }); - - // Insert event in to index - const insertIndex = this.#persist.scheduledEvents.findIndex( - (x) => x.timestamp > newEvent.timestamp, - ); - if (insertIndex === -1) { - this.#persist.scheduledEvents.push(newEvent); - } else { - this.#persist.scheduledEvents.splice(insertIndex, 0, newEvent); - } - - // Update alarm if: - // - this is the newest event (i.e. at beginning of array) or - // - this is the only event (i.e. the only event in the array) - if (insertIndex === 0 || this.#persist.scheduledEvents.length === 1) { - this.actorContext.log.info({ - msg: "setting alarm", - timestamp: newEvent.timestamp, - eventCount: this.#persist.scheduledEvents.length, - }); - await this.#queueSetAlarm(newEvent.timestamp); - } - } - - async scheduleEvent( - timestamp: number, - action: string, - args: unknown[], - ): Promise { - return this.#scheduleEventInner({ - eventId: crypto.randomUUID(), - timestamp, - action, - args: bufferToArrayBuffer(cbor.encode(args)), - }); - } - - /** - * Triggers any pending alarms. - * - * This method is idempotent. It's called automatically when the actor wakes - * in order to trigger any pending alarms. - */ - async _onAlarm() { - const now = Date.now(); - this.actorContext.log.debug({ - msg: "alarm triggered", - now, - events: this.#persist.scheduledEvents.length, - }); - - // Update sleep - // - // Do this before any async logic - this.#resetSleepTimer(); - - // Remove events from schedule that we're about to run - const runIndex = this.#persist.scheduledEvents.findIndex( - (x) => x.timestamp <= now, - ); - if (runIndex === -1) { - // This method is idempotent, so this will happen in scenarios like `start` and - // no events are pending. - this.#rLog.debug({ msg: "no events are due yet" }); - if (this.#persist.scheduledEvents.length > 0) { - const nextTs = this.#persist.scheduledEvents[0].timestamp; - this.actorContext.log.debug({ - msg: "alarm fired early, rescheduling for next event", - now, - nextTs, - delta: nextTs - now, - }); - await this.#queueSetAlarm(nextTs); - } - this.actorContext.log.debug({ msg: "no events to run", now }); - return; - } - const scheduleEvents = this.#persist.scheduledEvents.splice( - 0, - runIndex + 1, - ); - this.actorContext.log.debug({ - msg: "running events", - count: scheduleEvents.length, - }); - - // Set alarm for next event - if (this.#persist.scheduledEvents.length > 0) { - const nextTs = this.#persist.scheduledEvents[0].timestamp; - this.actorContext.log.info({ - msg: "setting next alarm", - nextTs, - remainingEvents: this.#persist.scheduledEvents.length, - }); - await this.#queueSetAlarm(nextTs); - } - - // Iterate by event key in order to ensure we call the events in order - for (const event of scheduleEvents) { - try { - this.actorContext.log.info({ - msg: "running action for event", - event: event.eventId, - timestamp: event.timestamp, - action: event.action, - }); - - // Look up function - const fn: unknown = this.#config.actions[event.action]; - - if (!fn) - throw new Error(`Missing action for alarm ${event.action}`); - if (typeof fn !== "function") - throw new Error( - `Alarm function lookup for ${event.action} returned ${typeof fn}`, - ); - - // Call function - try { - const args = event.args - ? cbor.decode(new Uint8Array(event.args)) - : []; - await fn.call(undefined, this.actorContext, ...args); - } catch (error) { - this.actorContext.log.error({ - msg: "error while running event", - error: stringifyError(error), - event: event.eventId, - timestamp: event.timestamp, - action: event.action, - }); - } - } catch (error) { - this.actorContext.log.error({ - msg: "internal error while running event", - error: stringifyError(error), - ...event, - }); - } - } - } - - async #queueSetAlarm(timestamp: number): Promise { - await this.#alarmWriteQueue.enqueue(async () => { - await this.#actorDriver.setAlarm(this, timestamp); - }); - } - - // MARK: Background Promises - /** Wait for background waitUntil promises with a timeout. */ - async #waitBackgroundPromises(timeoutMs: number) { - const pending = this.#backgroundPromises; - if (pending.length === 0) { - this.#rLog.debug({ msg: "no background promises" }); - return; - } - - // Race promises with timeout to determine if pending promises settled fast enough - const timedOut = await Promise.race([ - Promise.allSettled(pending).then(() => false), - new Promise((resolve) => - setTimeout(() => resolve(true), timeoutMs), - ), - ]); - - if (timedOut) { - this.#rLog.error({ - msg: "timed out waiting for background tasks, background promises may have leaked", - count: pending.length, - timeoutMs, - }); - } else { - this.#rLog.debug({ msg: "background promises finished" }); - } - } - - /** - * Prevents the actor from sleeping until promise is complete. - * - * This allows the actor runtime to ensure that a promise completes while - * returning from an action request early. - * - * @param promise - The promise to run in the background. - */ - _waitUntil(promise: Promise) { - this.#assertReady(); - - // TODO: Should we force save the state? - // Add logging to promise and make it non-failable - const nonfailablePromise = promise - .then(() => { - this.#rLog.debug({ msg: "wait until promise complete" }); - }) - .catch((error) => { - this.#rLog.error({ - msg: "wait until promise failed", - error: stringifyError(error), - }); - }); - this.#backgroundPromises.push(nonfailablePromise); - } - - // MARK: BARE Conversion Helpers - #convertToBarePersisted( - persist: PersistedActor, - ): persistSchema.Actor { - // Convert hibernatable connections from the in-memory connections map - // Convert hibernatableConns from the persisted structure - const hibernatableConns: persistSchema.HibernatableConn[] = - persist.hibernatableConns.map((conn) => ({ - id: conn.id, - parameters: bufferToArrayBuffer( - cbor.encode(conn.parameters || {}), - ), - state: bufferToArrayBuffer(cbor.encode(conn.state || {})), - subscriptions: conn.subscriptions.map((sub) => ({ - eventName: sub.eventName, - })), - hibernatableRequestId: conn.hibernatableRequestId, - lastSeenTimestamp: BigInt(conn.lastSeenTimestamp), - msgIndex: BigInt(conn.msgIndex), - })); - - return { - input: - persist.input !== undefined - ? bufferToArrayBuffer(cbor.encode(persist.input)) - : null, - hasInitialized: persist.hasInitialized, - state: bufferToArrayBuffer(cbor.encode(persist.state)), - hibernatableConns, - scheduledEvents: persist.scheduledEvents.map((event) => ({ - eventId: event.eventId, - timestamp: BigInt(event.timestamp), - action: event.action, - args: event.args ?? null, - })), - }; - } - - #convertFromBarePersisted( - bareData: persistSchema.Actor, - ): PersistedActor { - // Convert hibernatableConns from the BARE schema format - const hibernatableConns: PersistedHibernatableConn[] = - bareData.hibernatableConns.map((conn) => ({ - id: conn.id, - parameters: cbor.decode(new Uint8Array(conn.parameters)) as CP, - state: cbor.decode(new Uint8Array(conn.state)) as CS, - subscriptions: conn.subscriptions.map((sub) => ({ - eventName: sub.eventName, - })), - hibernatableRequestId: conn.hibernatableRequestId, - lastSeenTimestamp: Number(conn.lastSeenTimestamp), - msgIndex: Number(conn.msgIndex), - })); - - return { - input: bareData.input - ? cbor.decode(new Uint8Array(bareData.input)) - : undefined, - hasInitialized: bareData.hasInitialized, - state: cbor.decode(new Uint8Array(bareData.state)), - hibernatableConns, - scheduledEvents: bareData.scheduledEvents.map((event) => ({ - eventId: event.eventId, - timestamp: Number(event.timestamp), - action: event.action, - args: event.args ?? undefined, - })), - }; - } -} diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/connection-manager.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/connection-manager.ts new file mode 100644 index 0000000000..23d6481ed0 --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/connection-manager.ts @@ -0,0 +1,403 @@ +import * as cbor from "cbor-x"; +import { arrayBuffersEqual, idToStr, stringifyError } from "@/utils"; +import type { OnConnectOptions } from "../config"; +import type { ConnDriver } from "../conn/driver"; +import { + CONN_DRIVER_SYMBOL, + CONN_PERSIST_SYMBOL, + Conn, + type ConnId, +} from "../conn/mod"; +import type { AnyDatabaseProvider } from "../database"; +import type { ActorDriver } from "../driver"; +import { deadline } from "../utils"; +import { makeConnKey } from "./kv"; +import { ACTOR_INSTANCE_PERSIST_SYMBOL, type ActorInstance } from "./mod"; +import type { PersistedConn } from "./persisted"; + +/** + * Manages all connection-related operations for an actor instance. + * Handles connection creation, tracking, hibernation, and cleanup. + */ +export class ConnectionManager< + S, + CP, + CS, + V, + I, + DB extends AnyDatabaseProvider, +> { + #actor: ActorInstance; + #connections = new Map>(); + #changedConnections = new Set(); + + constructor(actor: ActorInstance) { + this.#actor = actor; + } + + // MARK: - Public API + + get connections(): Map> { + return this.#connections; + } + + get changedConnections(): Set { + return this.#changedConnections; + } + + clearChangedConnections() { + this.#changedConnections.clear(); + } + + getConnForId(id: string): Conn | undefined { + return this.#connections.get(id); + } + + markConnChanged(conn: Conn) { + this.#changedConnections.add(conn.id); + this.#actor.rLog.debug({ + msg: "marked connection as changed", + connId: conn.id, + totalChanged: this.#changedConnections.size, + }); + } + + // MARK: - Connection Lifecycle + + /** + * Creates a new connection or reconnects an existing hibernatable connection. + */ + async createConn( + driver: ConnDriver, + params: CP, + request?: Request, + ): Promise> { + // Check for hibernatable websocket reconnection + if (driver.requestIdBuf && driver.hibernatable) { + const existingConn = this.#findHibernatableConn( + driver.requestIdBuf, + ); + + if (existingConn) { + return this.#reconnectHibernatableConn(existingConn, driver); + } + } + + // Create new connection + return await this.#createNewConn(driver, params, request); + } + + /** + * Handle connection disconnection. + * Clean disconnects remove the connection immediately. + * Unclean disconnects keep the connection for potential reconnection. + */ + async connDisconnected( + conn: Conn, + wasClean: boolean, + actorDriver: ActorDriver, + eventManager: any, // EventManager type + ) { + if (wasClean) { + // Clean disconnect - remove immediately + await this.removeConn(conn, actorDriver, eventManager); + } else { + // Unclean disconnect - keep for reconnection + this.#handleUncleanDisconnect(conn); + } + } + + /** + * Removes a connection and cleans up its resources. + */ + async removeConn( + conn: Conn, + actorDriver: ActorDriver, + eventManager: any, // EventManager type + ) { + // Remove from KV storage + const key = makeConnKey(conn.id); + try { + await actorDriver.kvBatchDelete(this.#actor.id, [key]); + this.#actor.rLog.debug({ + msg: "removed connection from KV", + connId: conn.id, + }); + } catch (err) { + this.#actor.rLog.error({ + msg: "kvBatchDelete failed for conn", + err: stringifyError(err), + }); + } + + // Remove from tracking + this.#connections.delete(conn.id); + this.#changedConnections.delete(conn.id); + this.#actor.rLog.debug({ msg: "removed conn", connId: conn.id }); + + // Clean up subscriptions via EventManager + if (eventManager) { + for (const eventName of [...conn.subscriptions.values()]) { + eventManager.removeSubscription(eventName, conn, true); + } + } + + // Emit events and call lifecycle hooks + this.#actor.inspector.emitter.emit("connectionUpdated"); + + const config = (this.#actor as any).config; + if (config?.onDisconnect) { + try { + const result = config.onDisconnect( + this.#actor.actorContext, + conn, + ); + if (result instanceof Promise) { + result.catch((error: any) => { + this.#actor.rLog.error({ + msg: "error in `onDisconnect`", + error: stringifyError(error), + }); + }); + } + } catch (error) { + this.#actor.rLog.error({ + msg: "error in `onDisconnect`", + error: stringifyError(error), + }); + } + } + } + + // MARK: - Persistence + + /** + * Restores connections from persisted data during actor initialization. + */ + restoreConnections( + connections: PersistedConn[], + eventManager: any, // EventManager type + ) { + for (const connPersist of connections) { + // Create connection instance + const conn = new Conn( + this.#actor, + connPersist, + ); + this.#connections.set(conn.id, conn); + + // Restore subscriptions + for (const sub of connPersist.subscriptions) { + eventManager.addSubscription(sub.eventName, conn, true); + } + } + } + + /** + * Gets persistence data for all changed connections. + */ + getChangedConnectionsData(): Array<[Uint8Array, Uint8Array]> { + const entries: Array<[Uint8Array, Uint8Array]> = []; + + for (const connId of this.#changedConnections) { + const conn = this.#connections.get(connId); + if (conn) { + const connData = cbor.encode(conn.persistRaw); + entries.push([makeConnKey(connId), connData]); + conn.markSaved(); + } + } + + return entries; + } + + // MARK: - Private Helpers + + #findHibernatableConn( + requestIdBuf: ArrayBuffer, + ): Conn | undefined { + return Array.from(this.#connections.values()).find( + (conn) => + conn[CONN_PERSIST_SYMBOL].hibernatableRequestId && + arrayBuffersEqual( + conn[CONN_PERSIST_SYMBOL].hibernatableRequestId, + requestIdBuf, + ), + ); + } + + #reconnectHibernatableConn( + existingConn: Conn, + driver: ConnDriver, + ): Conn { + this.#actor.rLog.debug({ + msg: "reconnecting hibernatable websocket connection", + connectionId: existingConn.id, + requestId: driver.requestId, + }); + + // Clean up existing driver state if present + if (existingConn[CONN_DRIVER_SYMBOL]) { + this.#cleanupDriverState(existingConn); + } + + // Update connection with new socket + existingConn[CONN_DRIVER_SYMBOL] = driver; + existingConn[CONN_PERSIST_SYMBOL].lastSeen = Date.now(); + + this.#actor.inspector.emitter.emit("connectionUpdated"); + + return existingConn; + } + + #cleanupDriverState(conn: Conn) { + const driver = conn[CONN_DRIVER_SYMBOL]; + if (driver?.disconnect) { + driver.disconnect( + this.#actor, + conn, + "Reconnecting hibernatable websocket with new driver state", + ); + } + } + + async #createNewConn( + driver: ConnDriver, + params: CP, + request: Request | undefined, + ): Promise> { + const config = this.#actor.config; + const persist = (this.#actor as any)[ACTOR_INSTANCE_PERSIST_SYMBOL]; + // Prepare connection state + let connState: CS | undefined; + + const onBeforeConnectOpts = { + request, + } satisfies OnConnectOptions; + + // Call onBeforeConnect hook + if (config.onBeforeConnect) { + await config.onBeforeConnect( + this.#actor.actorContext, + onBeforeConnectOpts, + params, + ); + } + + // Create connection state if enabled + if ((this.#actor as any).connStateEnabled) { + connState = await this.#createConnState( + config, + onBeforeConnectOpts, + params, + ); + } + + // Create connection persist data + const connPersist: PersistedConn = { + connId: crypto.randomUUID(), + params: params, + state: connState as CS, + lastSeen: Date.now(), + subscriptions: [], + }; + + // Check if hibernatable + if (driver.requestIdBuf) { + const isHibernatable = this.#isHibernatableRequest( + driver.requestIdBuf, + persist, + ); + if (isHibernatable) { + connPersist.hibernatableRequestId = driver.requestIdBuf; + } + } + + // Create connection instance + const conn = new Conn(this.#actor, connPersist); + conn[CONN_DRIVER_SYMBOL] = driver; + this.#connections.set(conn.id, conn); + + // Mark as changed for persistence + this.#changedConnections.add(conn.id); + + // Call onConnect lifecycle hook + if (config.onConnect) { + this.#callOnConnect(config, conn); + } + + this.#actor.inspector.emitter.emit("connectionUpdated"); + + return conn; + } + + async #createConnState( + config: any, + opts: OnConnectOptions, + params: CP, + ): Promise { + if ("createConnState" in config) { + const dataOrPromise = config.createConnState( + this.#actor.actorContext, + opts, + params, + ); + if (dataOrPromise instanceof Promise) { + return await deadline( + dataOrPromise, + config.options.createConnStateTimeout, + ); + } + return dataOrPromise; + } else if ("connState" in config) { + return structuredClone(config.connState); + } + + throw new Error( + "Could not create connection state from 'createConnState' or 'connState'", + ); + } + + #isHibernatableRequest(requestIdBuf: ArrayBuffer, persist: any): boolean { + return ( + persist.hibernatableConns.findIndex((conn: any) => + arrayBuffersEqual(conn.hibernatableRequestId, requestIdBuf), + ) !== -1 + ); + } + + #callOnConnect(config: any, conn: Conn) { + try { + const result = config.onConnect(this.#actor.actorContext, conn); + if (result instanceof Promise) { + deadline(result, config.options.onConnectTimeout).catch( + (error: any) => { + this.#actor.rLog.error({ + msg: "error in `onConnect`, closing socket", + error, + }); + conn?.disconnect("`onConnect` failed"); + }, + ); + } + } catch (error) { + this.#actor.rLog.error({ + msg: "error in `onConnect`", + error: stringifyError(error), + }); + conn?.disconnect("`onConnect` failed"); + } + } + + #handleUncleanDisconnect(conn: Conn) { + if (!conn[CONN_DRIVER_SYMBOL]) { + this.#actor.rLog.warn("called conn disconnected without driver"); + } + + // Update last seen for cleanup tracking + conn[CONN_PERSIST_SYMBOL].lastSeen = Date.now(); + + // Remove socket + conn[CONN_DRIVER_SYMBOL] = undefined; + } +} diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts new file mode 100644 index 0000000000..944617e002 --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts @@ -0,0 +1,281 @@ +import * as cbor from "cbor-x"; +import type * as protocol from "@/schemas/client-protocol/mod"; +import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned"; +import { bufferToArrayBuffer } from "@/utils"; +import { CONN_PERSIST_SYMBOL, type Conn } from "../conn/mod"; +import type { AnyDatabaseProvider } from "../database"; +import { CachedSerializer } from "../protocol/serde"; +import type { ActorInstance } from "./mod"; + +/** + * Manages event subscriptions and broadcasting for actor instances. + * Handles subscription tracking and efficient message distribution to connected clients. + */ +export class EventManager { + #actor: ActorInstance; + #subscriptionIndex = new Map>>(); + + constructor(actor: ActorInstance) { + this.#actor = actor; + } + + // MARK: - Public API + + /** + * Adds a subscription for a connection to an event. + * + * @param eventName - The name of the event to subscribe to + * @param connection - The connection subscribing to the event + * @param fromPersist - Whether this subscription is being restored from persistence + */ + addSubscription( + eventName: string, + connection: Conn, + fromPersist: boolean, + ) { + // Check if already subscribed + if (connection.subscriptions.has(eventName)) { + this.#actor.rLog.debug({ + msg: "connection already has subscription", + eventName, + connId: connection.id, + }); + return; + } + + // Update connection's subscription list + connection.subscriptions.add(eventName); + + // Update subscription index + let subscribers = this.#subscriptionIndex.get(eventName); + if (!subscribers) { + subscribers = new Set(); + this.#subscriptionIndex.set(eventName, subscribers); + } + subscribers.add(connection); + + // Persist subscription if not restoring from persistence + if (!fromPersist) { + connection[CONN_PERSIST_SYMBOL].subscriptions.push({ eventName }); + + // Mark connection as changed for persistence + const connectionManager = (this.#actor as any).connectionManager; + if (connectionManager) { + connectionManager.markConnChanged(connection); + } + + // Save state immediately + const stateManager = (this.#actor as any).stateManager; + if (stateManager) { + stateManager.saveState({ immediate: true }); + } + } + + this.#actor.rLog.debug({ + msg: "subscription added", + eventName, + connId: connection.id, + totalSubscribers: subscribers.size, + }); + } + + /** + * Removes a subscription for a connection from an event. + * + * @param eventName - The name of the event to unsubscribe from + * @param connection - The connection unsubscribing from the event + * @param fromRemoveConn - Whether this is being called as part of connection removal + */ + removeSubscription( + eventName: string, + connection: Conn, + fromRemoveConn: boolean, + ) { + // Check if subscription exists + if (!connection.subscriptions.has(eventName)) { + this.#actor.rLog.warn({ + msg: "connection does not have subscription", + eventName, + connId: connection.id, + }); + return; + } + + // Remove from connection's subscription list + connection.subscriptions.delete(eventName); + + // Update subscription index + const subscribers = this.#subscriptionIndex.get(eventName); + if (subscribers) { + subscribers.delete(connection); + if (subscribers.size === 0) { + this.#subscriptionIndex.delete(eventName); + } + } + + // Update persistence if not part of connection removal + if (!fromRemoveConn) { + // Remove from persisted subscriptions + const subIdx = connection[ + CONN_PERSIST_SYMBOL + ].subscriptions.findIndex((s) => s.eventName === eventName); + if (subIdx !== -1) { + connection[CONN_PERSIST_SYMBOL].subscriptions.splice(subIdx, 1); + } else { + this.#actor.rLog.warn({ + msg: "subscription does not exist in persist", + eventName, + connId: connection.id, + }); + } + + // Mark connection as changed for persistence + const connectionManager = (this.#actor as any).connectionManager; + if (connectionManager) { + connectionManager.markConnChanged(connection); + } + + // Save state immediately + const stateManager = (this.#actor as any).stateManager; + if (stateManager) { + stateManager.saveState({ immediate: true }); + } + } + + this.#actor.rLog.debug({ + msg: "subscription removed", + eventName, + connId: connection.id, + remainingSubscribers: subscribers?.size || 0, + }); + } + + /** + * Broadcasts an event to all subscribed connections. + * + * @param name - The name of the event to broadcast + * @param args - The arguments to send with the event + */ + broadcast>(name: string, ...args: Args) { + // Emit to inspector + this.#actor.inspector.emitter.emit("eventFired", { + type: "broadcast", + eventName: name, + args, + }); + + // Get subscribers for this event + const subscribers = this.#subscriptionIndex.get(name); + if (!subscribers || subscribers.size === 0) { + this.#actor.rLog.debug({ + msg: "no subscribers for event", + eventName: name, + }); + return; + } + + // Create serialized message + const toClientSerializer = new CachedSerializer( + { + body: { + tag: "Event", + val: { + name, + args: bufferToArrayBuffer(cbor.encode(args)), + }, + }, + }, + TO_CLIENT_VERSIONED, + ); + + // Send to all subscribers + let sentCount = 0; + for (const connection of subscribers) { + try { + connection.sendMessage(toClientSerializer); + sentCount++; + } catch (error) { + this.#actor.rLog.error({ + msg: "failed to send event to connection", + eventName: name, + connId: connection.id, + error: + error instanceof Error ? error.message : String(error), + }); + } + } + + this.#actor.rLog.debug({ + msg: "event broadcasted", + eventName: name, + subscriberCount: subscribers.size, + sentCount, + }); + } + + /** + * Gets all subscribers for a specific event. + * + * @param eventName - The name of the event + * @returns Set of connections subscribed to the event, or undefined if no subscribers + */ + getSubscribers( + eventName: string, + ): Set> | undefined { + return this.#subscriptionIndex.get(eventName); + } + + /** + * Gets all events and their subscriber counts. + * + * @returns Map of event names to subscriber counts + */ + getEventStats(): Map { + const stats = new Map(); + for (const [eventName, subscribers] of this.#subscriptionIndex) { + stats.set(eventName, subscribers.size); + } + return stats; + } + + /** + * Clears all subscriptions for a connection. + * Used during connection cleanup. + * + * @param connection - The connection to clear subscriptions for + */ + clearConnectionSubscriptions(connection: Conn) { + for (const eventName of [...connection.subscriptions.values()]) { + this.removeSubscription(eventName, connection, true); + } + } + + /** + * Gets the total number of unique events being subscribed to. + */ + get eventCount(): number { + return this.#subscriptionIndex.size; + } + + /** + * Gets the total number of subscriptions across all events. + */ + get totalSubscriptionCount(): number { + let total = 0; + for (const subscribers of this.#subscriptionIndex.values()) { + total += subscribers.size; + } + return total; + } + + /** + * Checks if an event has any subscribers. + * + * @param eventName - The name of the event to check + * @returns True if the event has at least one subscriber + */ + hasSubscribers(eventName: string): boolean { + const subscribers = this.#subscriptionIndex.get(eventName); + return subscribers !== undefined && subscribers.size > 0; + } +} diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/kv.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/kv.ts similarity index 100% rename from rivetkit-typescript/packages/rivetkit/src/actor/kv.ts rename to rivetkit-typescript/packages/rivetkit/src/actor/instance/kv.ts diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts new file mode 100644 index 0000000000..f7ea90be49 --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts @@ -0,0 +1,1027 @@ +import * as cbor from "cbor-x"; +import invariant from "invariant"; +import type { ActorKey } from "@/actor/mod"; +import type { Client } from "@/client/client"; +import { getBaseLogger, getIncludeTarget, type Logger } from "@/common/log"; +import { stringifyError } from "@/common/utils"; +import type { UniversalWebSocket } from "@/common/websocket-interface"; +import { ActorInspector } from "@/inspector/actor"; +import type { Registry } from "@/mod"; +import { ACTOR_VERSIONED } from "@/schemas/actor-persist/versioned"; +import type * as protocol from "@/schemas/client-protocol/mod"; +import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned"; +import { EXTRA_ERROR_LOG, idToStr } from "@/utils"; +import type { ActorConfig, InitContext } from "../config"; +import type { ConnDriver } from "../conn/driver"; +import { createHttpSocket } from "../conn/drivers/http"; +import { CONN_PERSIST_SYMBOL, type Conn, type ConnId } from "../conn/mod"; +import { ActionContext } from "../contexts/action"; +import { ActorContext } from "../contexts/actor"; +import type { AnyDatabaseProvider, InferDatabaseClient } from "../database"; +import type { ActorDriver } from "../driver"; +import * as errors from "../errors"; +import { serializeActorKey } from "../keys"; +import { processMessage } from "../protocol/old"; +import { CachedSerializer } from "../protocol/serde"; +import { Schedule } from "../schedule"; +import { DeadlineError, deadline } from "../utils"; +import { ConnectionManager } from "./connection-manager"; +import { EventManager } from "./event-manager"; +import { KEYS } from "./kv"; +import type { PersistedActor, PersistedConn } from "./persisted"; +import { ScheduleManager } from "./schedule-manager"; +import { type SaveStateOptions, StateManager } from "./state-manager"; + +export type { SaveStateOptions }; + +export const ACTOR_INSTANCE_PERSIST_SYMBOL = Symbol("persist"); + +enum CanSleep { + Yes, + NotReady, + ActiveConns, + ActiveHonoHttpRequests, + ActiveRawWebSockets, +} + +/** Actor type alias with all `any` types. Used for `extends` in classes referencing this actor. */ +export type AnyActorInstance = ActorInstance; + +export type ExtractActorState = + A extends ActorInstance + ? State + : never; + +export type ExtractActorConnParams = + A extends ActorInstance + ? ConnParams + : never; + +export type ExtractActorConnState = + A extends ActorInstance + ? ConnState + : never; + +// MARK: - Main ActorInstance Class +export class ActorInstance { + // MARK: - Core Properties + actorContext: ActorContext; + #config: ActorConfig; + #actorDriver!: ActorDriver; + #inlineClient!: Client>; + #actorId!: string; + #name!: string; + #key!: ActorKey; + #region!: string; + + // MARK: - Managers + #connectionManager!: ConnectionManager; + #stateManager!: StateManager; + #eventManager!: EventManager; + #scheduleManager!: ScheduleManager; + + // MARK: - Logging + #log!: Logger; + #rLog!: Logger; + + // MARK: - Lifecycle State + #ready = false; + #sleepCalled = false; + #stopCalled = false; + #sleepTimeout?: NodeJS.Timeout; + #abortController = new AbortController(); + + // MARK: - Variables & Database + #vars?: V; + #db!: InferDatabaseClient; + + // MARK: - Background Tasks + #backgroundPromises: Promise[] = []; + + // MARK: - HTTP/WebSocket Tracking + #activeHonoHttpRequests = 0; + #activeRawWebSockets = new Set(); + + // MARK: - Deprecated (kept for compatibility) + #schedule!: Schedule; + + // MARK: - Inspector + #inspector = new ActorInspector(() => { + return { + isDbEnabled: async () => { + return this.#db !== undefined; + }, + getDb: async () => { + return this.db; + }, + isStateEnabled: async () => { + return this.stateEnabled; + }, + getState: async () => { + if (!this.stateEnabled) { + throw new errors.StateNotEnabled(); + } + return this.#stateManager.persistRaw.state as Record< + string, + any + > as unknown; + }, + getRpcs: async () => { + return Object.keys(this.#config.actions); + }, + getConnections: async () => { + return Array.from( + this.#connectionManager.connections.entries(), + ).map(([id, conn]) => ({ + id, + params: conn.params as any, + state: conn.stateEnabled ? conn.state : undefined, + subscriptions: conn.subscriptions.size, + lastSeen: conn.lastSeen, + stateEnabled: conn.stateEnabled, + isHibernatable: conn.isHibernatable, + hibernatableRequestId: conn[CONN_PERSIST_SYMBOL] + .hibernatableRequestId + ? idToStr( + conn[CONN_PERSIST_SYMBOL].hibernatableRequestId, + ) + : undefined, + })); + }, + setState: async (state: unknown) => { + if (!this.stateEnabled) { + throw new errors.StateNotEnabled(); + } + this.#stateManager.state = { ...(state as S) }; + await this.#stateManager.saveState({ immediate: true }); + }, + executeAction: async (name, params) => { + const conn = await this.createConn( + createHttpSocket(), + undefined, + undefined, + ); + + try { + return await this.executeAction( + new ActionContext(this.actorContext, conn), + name, + params || [], + ); + } finally { + this.connDisconnected(conn, true); + } + }, + }; + }); + + // MARK: - Constructor + constructor(config: ActorConfig) { + this.#config = config; + this.actorContext = new ActorContext(this); + } + + // MARK: - Public Getters + get log(): Logger { + invariant(this.#log, "log not configured"); + return this.#log; + } + + get rLog(): Logger { + invariant(this.#rLog, "log not configured"); + return this.#rLog; + } + + get isStopping(): boolean { + return this.#stopCalled; + } + + get id(): string { + return this.#actorId; + } + + get name(): string { + return this.#name; + } + + get key(): ActorKey { + return this.#key; + } + + get region(): string { + return this.#region; + } + + get inlineClient(): Client> { + return this.#inlineClient; + } + + get inspector(): ActorInspector { + return this.#inspector; + } + + get conns(): Map> { + return this.#connectionManager.connections; + } + + get schedule(): Schedule { + return this.#schedule; + } + + get abortSignal(): AbortSignal { + return this.#abortController.signal; + } + + get actions(): string[] { + return Object.keys(this.#config.actions); + } + + get config(): ActorConfig { + return this.#config; + } + + // MARK: - State Access + get [ACTOR_INSTANCE_PERSIST_SYMBOL](): PersistedActor { + return this.#stateManager.persist; + } + + get state(): S { + return this.#stateManager.state; + } + + set state(value: S) { + this.#stateManager.state = value; + } + + get stateEnabled(): boolean { + return this.#stateManager.stateEnabled; + } + + get connStateEnabled(): boolean { + return "createConnState" in this.#config || "connState" in this.#config; + } + + // MARK: - Variables & Database + get vars(): V { + this.#validateVarsEnabled(); + invariant(this.#vars !== undefined, "vars not enabled"); + return this.#vars; + } + + get db(): InferDatabaseClient { + if (!this.#db) { + throw new errors.DatabaseNotEnabled(); + } + return this.#db; + } + + // MARK: - Initialization + async start( + actorDriver: ActorDriver, + inlineClient: Client>, + actorId: string, + name: string, + key: ActorKey, + region: string, + ) { + // Initialize properties + this.#actorDriver = actorDriver; + this.#inlineClient = inlineClient; + this.#actorId = actorId; + this.#name = name; + this.#key = key; + this.#region = region; + + // Initialize logging + this.#initializeLogging(); + + // Initialize managers + this.#connectionManager = new ConnectionManager(this); + this.#stateManager = new StateManager(this, actorDriver, this.#config); + this.#eventManager = new EventManager(this); + this.#scheduleManager = new ScheduleManager( + this, + actorDriver, + this.#config, + ); + + // Legacy schedule object (for compatibility) + this.#schedule = new Schedule(this); + + // Read and initialize state + await this.#initializeState(); + + // Initialize variables + if (this.#varsEnabled) { + await this.#initializeVars(); + } + + // Call onStart lifecycle + await this.#callOnStart(); + + // Setup database + await this.#setupDatabase(); + + // Initialize alarms + await this.#scheduleManager.initializeAlarms(); + + // Mark as ready + this.#ready = true; + this.#rLog.info({ msg: "actor ready" }); + + // Start sleep timer + this.#resetSleepTimer(); + + // Trigger any pending alarms + await this.onAlarm(); + } + + // MARK: - Ready Check + isReady(): boolean { + return this.#ready; + } + + #assertReady(allowStoppingState: boolean = false) { + if (!this.#ready) throw new errors.InternalError("Actor not ready"); + if (!allowStoppingState && this.#stopCalled) + throw new errors.InternalError("Actor is stopping"); + } + + // MARK: - Stop + async onStop() { + if (this.#stopCalled) { + this.#rLog.warn({ msg: "already stopping actor" }); + return; + } + this.#stopCalled = true; + this.#rLog.info({ msg: "actor stopping" }); + + // Clear sleep timeout + if (this.#sleepTimeout) { + clearTimeout(this.#sleepTimeout); + this.#sleepTimeout = undefined; + } + + // Abort listeners + try { + this.#abortController.abort(); + } catch {} + + // Call onStop lifecycle + await this.#callOnStop(); + + // Disconnect non-hibernatable connections + await this.#disconnectConnections(); + + // Wait for background tasks + await this.#waitBackgroundPromises( + this.#config.options.waitUntilTimeout, + ); + + // Clear timeouts and save state + this.#stateManager.clearPendingSaveTimeout(); + await this.saveState({ immediate: true, allowStoppingState: true }); + + // Wait for write queues + await this.#stateManager.waitForPendingWrites(); + await this.#scheduleManager.waitForPendingAlarmWrites(); + } + + // MARK: - Sleep + startSleep() { + if (this.#stopCalled) { + this.#rLog.debug({ + msg: "cannot call startSleep if actor already stopping", + }); + return; + } + + if (this.#sleepCalled) { + this.#rLog.warn({ + msg: "cannot call startSleep twice, actor already sleeping", + }); + return; + } + this.#sleepCalled = true; + + const sleep = this.#actorDriver.startSleep?.bind( + this.#actorDriver, + this.#actorId, + ); + invariant(this.#sleepingSupported, "sleeping not supported"); + invariant(sleep, "no sleep on driver"); + + this.#rLog.info({ msg: "actor sleeping" }); + + setImmediate(() => { + sleep(); + }); + } + + // MARK: - HTTP Request Tracking + beginHonoHttpRequest() { + this.#activeHonoHttpRequests++; + this.#resetSleepTimer(); + } + + endHonoHttpRequest() { + this.#activeHonoHttpRequests--; + if (this.#activeHonoHttpRequests < 0) { + this.#activeHonoHttpRequests = 0; + this.#rLog.warn({ + msg: "active hono requests went below 0, this is a RivetKit bug", + ...EXTRA_ERROR_LOG, + }); + } + this.#resetSleepTimer(); + } + + // MARK: - State Management + async saveState(opts: SaveStateOptions) { + this.#assertReady(opts.allowStoppingState); + + // Save state through StateManager + await this.#stateManager.saveState(opts); + + // Save connection changes + if (this.#connectionManager.changedConnections.size > 0) { + const entries = this.#connectionManager.getChangedConnectionsData(); + if (entries.length > 0) { + await this.#actorDriver.kvBatchPut(this.#actorId, entries); + } + this.#connectionManager.clearChangedConnections(); + } + } + + // MARK: - Connection Management + getConnForId(id: string): Conn | undefined { + return this.#connectionManager.getConnForId(id); + } + + markConnChanged(conn: Conn) { + this.#connectionManager.markConnChanged(conn); + } + + connDisconnected(conn: Conn, wasClean: boolean) { + this.#connectionManager.connDisconnected( + conn, + wasClean, + this.#actorDriver, + this.#eventManager, + ); + this.#resetSleepTimer(); + } + + async createConn( + driver: ConnDriver, + params: any, + request?: Request, + ): Promise> { + this.#assertReady(); + + const conn = await this.#connectionManager.createConn( + driver, + params, + request, + ); + + // Reset sleep timer after connection + this.#resetSleepTimer(); + + // Save state immediately + await this.saveState({ immediate: true }); + + // Send init message + conn.sendMessage( + new CachedSerializer( + { + body: { + tag: "Init", + val: { + actorId: this.id, + connectionId: conn.id, + }, + }, + }, + TO_CLIENT_VERSIONED, + ), + ); + + return conn; + } + + // MARK: - Message Processing + async processMessage( + message: protocol.ToServer, + conn: Conn, + ) { + await processMessage(message, this, conn, { + onExecuteAction: async (ctx, name, args) => { + this.inspector.emitter.emit("eventFired", { + type: "action", + name, + args, + connId: conn.id, + }); + return await this.executeAction(ctx, name, args); + }, + onSubscribe: async (eventName, conn) => { + this.inspector.emitter.emit("eventFired", { + type: "subscribe", + eventName, + connId: conn.id, + }); + this.#eventManager.addSubscription(eventName, conn, false); + }, + onUnsubscribe: async (eventName, conn) => { + this.inspector.emitter.emit("eventFired", { + type: "unsubscribe", + eventName, + connId: conn.id, + }); + this.#eventManager.removeSubscription(eventName, conn, false); + }, + }); + } + + // MARK: - Action Execution + async executeAction( + ctx: ActionContext, + actionName: string, + args: unknown[], + ): Promise { + invariant(this.#ready, "executing action before ready"); + + if (!(actionName in this.#config.actions)) { + this.#rLog.warn({ msg: "action does not exist", actionName }); + throw new errors.ActionNotFound(actionName); + } + + const actionFunction = this.#config.actions[actionName]; + if (typeof actionFunction !== "function") { + this.#rLog.warn({ + msg: "action is not a function", + actionName, + type: typeof actionFunction, + }); + throw new errors.ActionNotFound(actionName); + } + + try { + this.#rLog.debug({ + msg: "executing action", + actionName, + args, + }); + + const outputOrPromise = actionFunction.call( + undefined, + ctx, + ...args, + ); + + let output: unknown; + if (outputOrPromise instanceof Promise) { + output = await deadline( + outputOrPromise, + this.#config.options.actionTimeout, + ); + } else { + output = outputOrPromise; + } + + // Process through onBeforeActionResponse if configured + if (this.#config.onBeforeActionResponse) { + try { + const processedOutput = this.#config.onBeforeActionResponse( + this.actorContext, + actionName, + args, + output, + ); + if (processedOutput instanceof Promise) { + output = await processedOutput; + } else { + output = processedOutput; + } + } catch (error) { + this.#rLog.error({ + msg: "error in `onBeforeActionResponse`", + error: stringifyError(error), + }); + } + } + + return output; + } catch (error) { + if (error instanceof DeadlineError) { + throw new errors.ActionTimedOut(); + } + this.#rLog.error({ + msg: "action error", + actionName, + error: stringifyError(error), + }); + throw error; + } finally { + this.#stateManager.savePersistThrottled(); + } + } + + // MARK: - HTTP/WebSocket Handlers + async handleFetch( + request: Request, + opts: Record, + ): Promise { + this.#assertReady(); + + if (!this.#config.onFetch) { + throw new errors.FetchHandlerNotDefined(); + } + + try { + const response = await this.#config.onFetch( + this.actorContext, + request, + opts, + ); + if (!response) { + throw new errors.InvalidFetchResponse(); + } + return response; + } catch (error) { + this.#rLog.error({ + msg: "onFetch error", + error: stringifyError(error), + }); + throw error; + } finally { + this.#stateManager.savePersistThrottled(); + } + } + + async handleWebSocket( + websocket: UniversalWebSocket, + opts: { request: Request }, + ): Promise { + this.#assertReady(); + + if (!this.#config.onWebSocket) { + throw new errors.InternalError("onWebSocket handler not defined"); + } + + try { + const stateBeforeHandler = this.#stateManager.persistChanged; + + // Track active websocket + this.#activeRawWebSockets.add(websocket); + this.#resetSleepTimer(); + + // Setup WebSocket event handlers (simplified for brevity) + this.#setupWebSocketHandlers(websocket); + + // Handle WebSocket + await this.#config.onWebSocket(this.actorContext, websocket, opts); + + // Save state if changed + if (this.#stateManager.persistChanged && !stateBeforeHandler) { + await this.saveState({ immediate: true }); + } + } catch (error) { + this.#rLog.error({ + msg: "onWebSocket error", + error: stringifyError(error), + }); + throw error; + } finally { + this.#stateManager.savePersistThrottled(); + } + } + + // MARK: - Event Broadcasting + broadcast>(name: string, ...args: Args) { + this.#assertReady(); + this.#eventManager.broadcast(name, ...args); + } + + // MARK: - Scheduling + async scheduleEvent( + timestamp: number, + action: string, + args: unknown[], + ): Promise { + await this.#scheduleManager.scheduleEvent(timestamp, action, args); + } + + async onAlarm() { + this.#resetSleepTimer(); + await this.#scheduleManager.onAlarm(); + } + + // MARK: - Background Tasks + waitUntil(promise: Promise) { + this.#assertReady(); + + const nonfailablePromise = promise + .then(() => { + this.#rLog.debug({ msg: "wait until promise complete" }); + }) + .catch((error) => { + this.#rLog.error({ + msg: "wait until promise failed", + error: stringifyError(error), + }); + }); + this.#backgroundPromises.push(nonfailablePromise); + } + + // MARK: - Private Helper Methods + #initializeLogging() { + const logParams = { + actor: this.#name, + key: serializeActorKey(this.#key), + actorId: this.#actorId, + }; + + const extraLogParams = this.#actorDriver.getExtraActorLogParams?.(); + if (extraLogParams) Object.assign(logParams, extraLogParams); + + this.#log = getBaseLogger().child( + Object.assign( + getIncludeTarget() ? { target: "actor" } : {}, + logParams, + ), + ); + this.#rLog = getBaseLogger().child( + Object.assign( + getIncludeTarget() ? { target: "actor-runtime" } : {}, + logParams, + ), + ); + } + + async #initializeState() { + // Read initial state from KV + const [persistDataBuffer] = await this.#actorDriver.kvBatchGet( + this.#actorId, + [KEYS.PERSIST_DATA], + ); + invariant( + persistDataBuffer !== null, + "persist data has not been set, it should be set when initialized", + ); + + const bareData = + ACTOR_VERSIONED.deserializeWithEmbeddedVersion(persistDataBuffer); + const persistData = + this.#stateManager.convertFromBarePersisted(bareData); + + if (persistData.hasInitialized) { + // Restore existing actor + await this.#restoreExistingActor(persistData); + } else { + // Create new actor + await this.#createNewActor(persistData); + } + + // Pass persist reference to schedule manager + this.#scheduleManager.setPersist(this.#stateManager.persist); + } + + async #restoreExistingActor(persistData: PersistedActor) { + // List all connection keys + const connEntries = await this.#actorDriver.kvListPrefix( + this.#actorId, + KEYS.CONN_PREFIX, + ); + + // Decode connections + const connections: PersistedConn[] = []; + for (const [_key, value] of connEntries) { + try { + const conn = cbor.decode(value) as PersistedConn; + connections.push(conn); + } catch (error) { + this.#rLog.error({ + msg: "failed to decode connection", + error: stringifyError(error), + }); + } + } + + this.#rLog.info({ + msg: "actor restoring", + connections: connections.length, + hibernatableWebSockets: persistData.hibernatableConns.length, + }); + + // Initialize state + this.#stateManager.initPersistProxy(persistData); + + // Restore connections + this.#connectionManager.restoreConnections( + connections, + this.#eventManager, + ); + } + + async #createNewActor(persistData: PersistedActor) { + this.#rLog.info({ msg: "actor creating" }); + + // Initialize state + await this.#stateManager.initializeState(persistData); + + // Call onCreate lifecycle + if (this.#config.onCreate) { + await this.#config.onCreate(this.actorContext, persistData.input!); + } + } + + async #initializeVars() { + let vars: V | undefined; + if ("createVars" in this.#config) { + const dataOrPromise = this.#config.createVars( + this.actorContext as unknown as InitContext, + this.#actorDriver.getContext(this.#actorId), + ); + if (dataOrPromise instanceof Promise) { + vars = await deadline( + dataOrPromise, + this.#config.options.createVarsTimeout, + ); + } else { + vars = dataOrPromise; + } + } else if ("vars" in this.#config) { + vars = structuredClone(this.#config.vars); + } else { + throw new Error( + "Could not create variables from 'createVars' or 'vars'", + ); + } + this.#vars = vars; + } + + async #callOnStart() { + this.#rLog.info({ msg: "actor starting" }); + if (this.#config.onStart) { + const result = this.#config.onStart(this.actorContext); + if (result instanceof Promise) { + await result; + } + } + } + + async #callOnStop() { + if (this.#config.onStop) { + try { + this.#rLog.debug({ msg: "calling onStop" }); + const result = this.#config.onStop(this.actorContext); + if (result instanceof Promise) { + await deadline(result, this.#config.options.onStopTimeout); + } + this.#rLog.debug({ msg: "onStop completed" }); + } catch (error) { + if (error instanceof DeadlineError) { + this.#rLog.error({ msg: "onStop timed out" }); + } else { + this.#rLog.error({ + msg: "error in onStop", + error: stringifyError(error), + }); + } + } + } + } + + async #setupDatabase() { + if ("db" in this.#config && this.#config.db) { + const client = await this.#config.db.createClient({ + getDatabase: () => this.#actorDriver.getDatabase(this.#actorId), + }); + this.#rLog.info({ msg: "database migration starting" }); + await this.#config.db.onMigrate?.(client); + this.#rLog.info({ msg: "database migration complete" }); + this.#db = client; + } + } + + async #disconnectConnections() { + const promises: Promise[] = []; + for (const connection of this.#connectionManager.connections.values()) { + if (!connection.isHibernatable) { + this.#rLog.debug({ + msg: "disconnecting non-hibernatable connection on actor stop", + connId: connection.id, + }); + promises.push(connection.disconnect()); + } + } + + // Wait with timeout + const res = await Promise.race([ + Promise.all(promises).then(() => false), + new Promise((res) => + globalThis.setTimeout(() => res(true), 1500), + ), + ]); + + if (res) { + this.#rLog.warn({ + msg: "timed out waiting for connections to close, shutting down anyway", + }); + } + } + + async #waitBackgroundPromises(timeoutMs: number) { + const pending = this.#backgroundPromises; + if (pending.length === 0) { + this.#rLog.debug({ msg: "no background promises" }); + return; + } + + const timedOut = await Promise.race([ + Promise.allSettled(pending).then(() => false), + new Promise((resolve) => + setTimeout(() => resolve(true), timeoutMs), + ), + ]); + + if (timedOut) { + this.#rLog.error({ + msg: "timed out waiting for background tasks", + count: pending.length, + timeoutMs, + }); + } else { + this.#rLog.debug({ msg: "background promises finished" }); + } + } + + #setupWebSocketHandlers(websocket: UniversalWebSocket) { + // Simplified WebSocket handler setup + // Full implementation would track hibernatable websockets + const onSocketClosed = () => { + this.#activeRawWebSockets.delete(websocket); + this.#resetSleepTimer(); + }; + + websocket.addEventListener("close", onSocketClosed); + websocket.addEventListener("error", onSocketClosed); + } + + #resetSleepTimer() { + if (this.#config.options.noSleep || !this.#sleepingSupported) return; + if (this.#stopCalled) return; + + const canSleep = this.#canSleep(); + + this.#rLog.debug({ + msg: "resetting sleep timer", + canSleep: CanSleep[canSleep], + existingTimeout: !!this.#sleepTimeout, + timeout: this.#config.options.sleepTimeout, + }); + + if (this.#sleepTimeout) { + clearTimeout(this.#sleepTimeout); + this.#sleepTimeout = undefined; + } + + if (this.#sleepCalled) return; + + if (canSleep === CanSleep.Yes) { + this.#sleepTimeout = setTimeout(() => { + this.startSleep(); + }, this.#config.options.sleepTimeout); + } + } + + #canSleep(): CanSleep { + if (!this.#ready) return CanSleep.NotReady; + if (this.#activeHonoHttpRequests > 0) + return CanSleep.ActiveHonoHttpRequests; + if (this.#activeRawWebSockets.size > 0) + return CanSleep.ActiveRawWebSockets; + + for (const _conn of this.#connectionManager.connections.values()) { + return CanSleep.ActiveConns; + } + + return CanSleep.Yes; + } + + get #sleepingSupported(): boolean { + return this.#actorDriver.startSleep !== undefined; + } + + get #varsEnabled(): boolean { + return "createVars" in this.#config || "vars" in this.#config; + } + + #validateVarsEnabled() { + if (!this.#varsEnabled) { + throw new errors.VarsNotEnabled(); + } + } +} diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/persisted.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/persisted.ts similarity index 100% rename from rivetkit-typescript/packages/rivetkit/src/actor/persisted.ts rename to rivetkit-typescript/packages/rivetkit/src/actor/instance/persisted.ts diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/schedule-manager.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/schedule-manager.ts new file mode 100644 index 0000000000..2291b23f07 --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/schedule-manager.ts @@ -0,0 +1,349 @@ +import * as cbor from "cbor-x"; +import { + bufferToArrayBuffer, + SinglePromiseQueue, + stringifyError, +} from "@/utils"; +import type { AnyDatabaseProvider } from "../database"; +import type { ActorDriver } from "../driver"; +import type { ActorInstance } from "./mod"; +import type { PersistedScheduleEvent } from "./persisted"; + +/** + * Manages scheduled events and alarms for actor instances. + * Handles event scheduling, alarm triggers, and automatic event execution. + */ +export class ScheduleManager { + #actor: ActorInstance; + #actorDriver: ActorDriver; + #alarmWriteQueue = new SinglePromiseQueue(); + #config: any; // ActorConfig type + #persist: any; // Reference to PersistedActor + + constructor( + actor: ActorInstance, + actorDriver: ActorDriver, + config: any, + ) { + this.#actor = actor; + this.#actorDriver = actorDriver; + this.#config = config; + } + + // MARK: - Public API + + /** + * Sets the persist object reference. + * Called after StateManager initializes the persist proxy. + */ + setPersist(persist: any) { + this.#persist = persist; + } + + /** + * Schedules an event to be executed at a specific timestamp. + * + * @param timestamp - Unix timestamp in milliseconds when the event should fire + * @param action - The name of the action to execute + * @param args - Arguments to pass to the action + */ + async scheduleEvent( + timestamp: number, + action: string, + args: unknown[], + ): Promise { + const newEvent: PersistedScheduleEvent = { + eventId: crypto.randomUUID(), + timestamp, + action, + args: bufferToArrayBuffer(cbor.encode(args)), + }; + + await this.#scheduleEventInner(newEvent); + } + + /** + * Triggers any pending alarms that are due. + * This method is idempotent and safe to call multiple times. + */ + async onAlarm(): Promise { + const now = Date.now(); + this.#actor.log.debug({ + msg: "alarm triggered", + now, + events: this.#persist?.scheduledEvents?.length || 0, + }); + + if (!this.#persist?.scheduledEvents) { + this.#actor.rLog.debug({ msg: "no scheduled events" }); + return; + } + + // Find events that are due + const dueIndex = this.#persist.scheduledEvents.findIndex( + (x: PersistedScheduleEvent) => x.timestamp <= now, + ); + + if (dueIndex === -1) { + // No events are due yet + this.#actor.rLog.debug({ msg: "no events are due yet" }); + + // Reschedule alarm for next event if any exist + if (this.#persist.scheduledEvents.length > 0) { + const nextTs = this.#persist.scheduledEvents[0].timestamp; + this.#actor.log.debug({ + msg: "alarm fired early, rescheduling for next event", + now, + nextTs, + delta: nextTs - now, + }); + await this.#queueSetAlarm(nextTs); + } + return; + } + + // Remove and process due events + const dueEvents = this.#persist.scheduledEvents.splice(0, dueIndex + 1); + this.#actor.log.debug({ + msg: "running events", + count: dueEvents.length, + }); + + // Schedule next alarm if more events remain + if (this.#persist.scheduledEvents.length > 0) { + const nextTs = this.#persist.scheduledEvents[0].timestamp; + this.#actor.log.info({ + msg: "setting next alarm", + nextTs, + remainingEvents: this.#persist.scheduledEvents.length, + }); + await this.#queueSetAlarm(nextTs); + } + + // Execute due events + await this.#executeDueEvents(dueEvents); + } + + /** + * Initializes alarms on actor startup. + * Sets the alarm for the next scheduled event if any exist. + */ + async initializeAlarms(): Promise { + if (this.#persist?.scheduledEvents?.length > 0) { + await this.#queueSetAlarm( + this.#persist.scheduledEvents[0].timestamp, + ); + } + } + + /** + * Waits for any pending alarm write operations to complete. + */ + async waitForPendingAlarmWrites(): Promise { + if (this.#alarmWriteQueue.runningDrainLoop) { + await this.#alarmWriteQueue.runningDrainLoop; + } + } + + /** + * Gets statistics about scheduled events. + */ + getScheduleStats(): { + totalEvents: number; + nextEventTime: number | null; + overdueCount: number; + } { + if (!this.#persist?.scheduledEvents) { + return { + totalEvents: 0, + nextEventTime: null, + overdueCount: 0, + }; + } + + const now = Date.now(); + const events = this.#persist.scheduledEvents; + + return { + totalEvents: events.length, + nextEventTime: events.length > 0 ? events[0].timestamp : null, + overdueCount: events.filter( + (e: PersistedScheduleEvent) => e.timestamp <= now, + ).length, + }; + } + + /** + * Cancels a scheduled event by its ID. + * + * @param eventId - The ID of the event to cancel + * @returns True if the event was found and cancelled + */ + async cancelEvent(eventId: string): Promise { + if (!this.#persist?.scheduledEvents) { + return false; + } + + const index = this.#persist.scheduledEvents.findIndex( + (e: PersistedScheduleEvent) => e.eventId === eventId, + ); + + if (index === -1) { + return false; + } + + // Remove the event + const wasFirst = index === 0; + this.#persist.scheduledEvents.splice(index, 1); + + // If we removed the first event, update the alarm + if (wasFirst && this.#persist.scheduledEvents.length > 0) { + await this.#queueSetAlarm( + this.#persist.scheduledEvents[0].timestamp, + ); + } + + this.#actor.log.info({ + msg: "cancelled scheduled event", + eventId, + remainingEvents: this.#persist.scheduledEvents.length, + }); + + return true; + } + + // MARK: - Private Helpers + + async #scheduleEventInner(newEvent: PersistedScheduleEvent): Promise { + this.#actor.log.info({ + msg: "scheduling event", + eventId: newEvent.eventId, + timestamp: newEvent.timestamp, + action: newEvent.action, + }); + + if (!this.#persist?.scheduledEvents) { + throw new Error("Persist not initialized"); + } + + // Find insertion point (events are sorted by timestamp) + const insertIndex = this.#persist.scheduledEvents.findIndex( + (x: PersistedScheduleEvent) => x.timestamp > newEvent.timestamp, + ); + + if (insertIndex === -1) { + // Add to end + this.#persist.scheduledEvents.push(newEvent); + } else { + // Insert at correct position + this.#persist.scheduledEvents.splice(insertIndex, 0, newEvent); + } + + // Update alarm if this is the newest event + if (insertIndex === 0 || this.#persist.scheduledEvents.length === 1) { + this.#actor.log.info({ + msg: "setting alarm for new event", + timestamp: newEvent.timestamp, + eventCount: this.#persist.scheduledEvents.length, + }); + await this.#queueSetAlarm(newEvent.timestamp); + } + } + + async #executeDueEvents(events: PersistedScheduleEvent[]): Promise { + for (const event of events) { + try { + this.#actor.log.info({ + msg: "executing scheduled event", + eventId: event.eventId, + timestamp: event.timestamp, + action: event.action, + }); + + // Look up the action function + const fn = this.#config.actions[event.action]; + + if (!fn) { + throw new Error( + `Missing action for scheduled event: ${event.action}`, + ); + } + + if (typeof fn !== "function") { + throw new Error( + `Scheduled event action ${event.action} is not a function (got ${typeof fn})`, + ); + } + + // Decode arguments and execute + const args = event.args + ? cbor.decode(new Uint8Array(event.args)) + : []; + + const result = fn.call( + undefined, + this.#actor.actorContext, + ...args, + ); + + // Handle async actions + if (result instanceof Promise) { + await result; + } + + this.#actor.log.debug({ + msg: "scheduled event completed", + eventId: event.eventId, + action: event.action, + }); + } catch (error) { + this.#actor.log.error({ + msg: "error executing scheduled event", + error: stringifyError(error), + eventId: event.eventId, + timestamp: event.timestamp, + action: event.action, + }); + + // Continue processing other events even if one fails + } + } + } + + async #queueSetAlarm(timestamp: number): Promise { + await this.#alarmWriteQueue.enqueue(async () => { + await this.#actorDriver.setAlarm(this.#actor, timestamp); + }); + } + + /** + * Gets the next scheduled event, if any. + */ + getNextEvent(): PersistedScheduleEvent | null { + if ( + !this.#persist?.scheduledEvents || + this.#persist.scheduledEvents.length === 0 + ) { + return null; + } + return this.#persist.scheduledEvents[0]; + } + + /** + * Gets all scheduled events. + */ + getAllEvents(): PersistedScheduleEvent[] { + return this.#persist?.scheduledEvents || []; + } + + /** + * Clears all scheduled events. + * Use with caution - this removes all pending scheduled events. + */ + clearAllEvents(): void { + if (this.#persist?.scheduledEvents) { + this.#persist.scheduledEvents = []; + this.#actor.log.warn({ msg: "cleared all scheduled events" }); + } + } +} diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/state-manager.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/state-manager.ts new file mode 100644 index 0000000000..91fcdf9c83 --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/state-manager.ts @@ -0,0 +1,440 @@ +import * as cbor from "cbor-x"; +import onChange from "on-change"; +import { isCborSerializable, stringifyError } from "@/common/utils"; +import type * as persistSchema from "@/schemas/actor-persist/mod"; +import { ACTOR_VERSIONED } from "@/schemas/actor-persist/versioned"; +import { + bufferToArrayBuffer, + promiseWithResolvers, + SinglePromiseQueue, +} from "@/utils"; +import type { ActorDriver } from "../driver"; +import * as errors from "../errors"; +import { isConnStatePath, isStatePath } from "../utils"; +import { KEYS } from "./kv"; +import type { ActorInstance } from "./mod"; +import type { PersistedActor } from "./persisted"; + +export interface SaveStateOptions { + /** + * Forces the state to be saved immediately. This function will return when the state has saved successfully. + */ + immediate?: boolean; + /** Bypass ready check for stopping. */ + allowStoppingState?: boolean; +} + +/** + * Manages actor state persistence, proxying, and synchronization. + * Handles automatic state change detection and throttled persistence to KV storage. + */ +export class StateManager { + #actor: ActorInstance; + #actorDriver: ActorDriver; + + // State tracking + #persist!: PersistedActor; + #persistRaw!: PersistedActor; + #persistChanged = false; + #isInOnStateChange = false; + + // Save management + #persistWriteQueue = new SinglePromiseQueue(); + #lastSaveTime = 0; + #pendingSaveTimeout?: NodeJS.Timeout; + #onPersistSavedPromise?: ReturnType>; + + // Configuration + #config: any; // ActorConfig type + #stateSaveInterval: number; + + constructor( + actor: ActorInstance, + actorDriver: ActorDriver, + config: any, + ) { + this.#actor = actor; + this.#actorDriver = actorDriver; + this.#config = config; + this.#stateSaveInterval = config.options.stateSaveInterval || 100; + } + + // MARK: - Public API + + get persist(): PersistedActor { + return this.#persist; + } + + get persistRaw(): PersistedActor { + return this.#persistRaw; + } + + get persistChanged(): boolean { + return this.#persistChanged; + } + + get state(): S { + this.#validateStateEnabled(); + return this.#persist.state; + } + + set state(value: S) { + this.#validateStateEnabled(); + this.#persist.state = value; + } + + get stateEnabled(): boolean { + return "createState" in this.#config || "state" in this.#config; + } + + // MARK: - Initialization + + /** + * Initializes state from persisted data or creates new state. + */ + async initializeState( + persistData: PersistedActor, + ): Promise { + if (!persistData.hasInitialized) { + // Create initial state + let stateData: unknown; + if (this.stateEnabled) { + this.#actor.rLog.info({ msg: "actor state initializing" }); + + if ("createState" in this.#config) { + stateData = await this.#config.createState( + this.#actor.actorContext, + persistData.input!, + ); + } else if ("state" in this.#config) { + stateData = structuredClone(this.#config.state); + } else { + throw new Error( + "Both 'createState' or 'state' were not defined", + ); + } + } else { + this.#actor.rLog.debug({ msg: "state not enabled" }); + } + + // Update persisted data + persistData.state = stateData as S; + persistData.hasInitialized = true; + + // Save initial state + await this.#writePersistedDataDirect(persistData); + } + + // Initialize proxy + this.initPersistProxy(persistData); + } + + /** + * Creates proxy for persist object that handles automatic state change detection. + */ + initPersistProxy(target: PersistedActor) { + // Set raw persist object + this.#persistRaw = target; + + // Validate serializability + if (target === null || typeof target !== "object") { + let invalidPath = ""; + if ( + !isCborSerializable( + target, + (path) => { + invalidPath = path; + }, + "", + ) + ) { + throw new errors.InvalidStateType({ path: invalidPath }); + } + return target; + } + + // Unsubscribe from old state + if (this.#persist) { + onChange.unsubscribe(this.#persist); + } + + // Listen for changes to automatically write state + this.#persist = onChange( + target, + ( + path: string, + value: any, + _previousValue: any, + _applyData: any, + ) => { + this.#handleStateChange(path, value); + }, + { ignoreDetached: true }, + ); + } + + // MARK: - State Persistence + + /** + * Forces the state to get saved. + */ + async saveState(opts: SaveStateOptions): Promise { + this.#actor.rLog.debug({ + msg: "saveState called", + persistChanged: this.#persistChanged, + allowStoppingState: opts.allowStoppingState, + immediate: opts.immediate, + }); + + if (this.#persistChanged) { + if (opts.immediate) { + await this.#savePersistInner(); + } else { + // Create promise for waiting + if (!this.#onPersistSavedPromise) { + this.#onPersistSavedPromise = promiseWithResolvers(); + } + + // Save throttled + this.savePersistThrottled(); + + // Wait for save + await this.#onPersistSavedPromise.promise; + } + } + } + + /** + * Throttled save state method. Used to write to KV at a reasonable cadence. + */ + savePersistThrottled() { + const now = Date.now(); + const timeSinceLastSave = now - this.#lastSaveTime; + + if (timeSinceLastSave < this.#stateSaveInterval) { + // Schedule next save if not already scheduled + if (this.#pendingSaveTimeout === undefined) { + this.#pendingSaveTimeout = setTimeout(() => { + this.#pendingSaveTimeout = undefined; + this.#savePersistInner(); + }, this.#stateSaveInterval - timeSinceLastSave); + } + } else { + // Save immediately + this.#savePersistInner(); + } + } + + /** + * Clears any pending save timeout. + */ + clearPendingSaveTimeout() { + if (this.#pendingSaveTimeout) { + clearTimeout(this.#pendingSaveTimeout); + this.#pendingSaveTimeout = undefined; + } + } + + /** + * Waits for any pending write operations to complete. + */ + async waitForPendingWrites(): Promise { + if (this.#persistWriteQueue.runningDrainLoop) { + await this.#persistWriteQueue.runningDrainLoop; + } + } + + /** + * Gets persistence data entries if state has changed. + */ + getPersistedDataIfChanged(): [Uint8Array, Uint8Array] | null { + if (!this.#persistChanged) return null; + + this.#persistChanged = false; + + const bareData = this.convertToBarePersisted(this.#persistRaw); + return [ + KEYS.PERSIST_DATA, + ACTOR_VERSIONED.serializeWithEmbeddedVersion(bareData), + ]; + } + + // MARK: - BARE Conversion + + convertToBarePersisted( + persist: PersistedActor, + ): persistSchema.Actor { + const hibernatableConns: persistSchema.HibernatableConn[] = + persist.hibernatableConns.map((conn) => ({ + id: conn.id, + parameters: bufferToArrayBuffer( + cbor.encode(conn.parameters || {}), + ), + state: bufferToArrayBuffer(cbor.encode(conn.state || {})), + subscriptions: conn.subscriptions.map((sub) => ({ + eventName: sub.eventName, + })), + hibernatableRequestId: conn.hibernatableRequestId, + lastSeenTimestamp: BigInt(conn.lastSeenTimestamp), + msgIndex: BigInt(conn.msgIndex), + })); + + return { + input: + persist.input !== undefined + ? bufferToArrayBuffer(cbor.encode(persist.input)) + : null, + hasInitialized: persist.hasInitialized, + state: bufferToArrayBuffer(cbor.encode(persist.state)), + hibernatableConns, + scheduledEvents: persist.scheduledEvents.map((event) => ({ + eventId: event.eventId, + timestamp: BigInt(event.timestamp), + action: event.action, + args: event.args ?? null, + })), + }; + } + + convertFromBarePersisted( + bareData: persistSchema.Actor, + ): PersistedActor { + const hibernatableConns = bareData.hibernatableConns.map((conn) => ({ + id: conn.id, + parameters: cbor.decode(new Uint8Array(conn.parameters)), + state: cbor.decode(new Uint8Array(conn.state)), + subscriptions: conn.subscriptions.map((sub) => ({ + eventName: sub.eventName, + })), + hibernatableRequestId: conn.hibernatableRequestId, + lastSeenTimestamp: Number(conn.lastSeenTimestamp), + msgIndex: Number(conn.msgIndex), + })); + + return { + input: bareData.input + ? cbor.decode(new Uint8Array(bareData.input)) + : undefined, + hasInitialized: bareData.hasInitialized, + state: cbor.decode(new Uint8Array(bareData.state)), + hibernatableConns, + scheduledEvents: bareData.scheduledEvents.map((event) => ({ + eventId: event.eventId, + timestamp: Number(event.timestamp), + action: event.action, + args: event.args ?? undefined, + })), + }; + } + + // MARK: - Private Helpers + + #validateStateEnabled() { + if (!this.stateEnabled) { + throw new errors.StateNotEnabled(); + } + } + + #handleStateChange(path: string, value: any) { + const actorStatePath = isStatePath(path); + const connStatePath = isConnStatePath(path); + + // Validate CBOR serializability + if (actorStatePath || connStatePath) { + let invalidPath = ""; + if ( + !isCborSerializable( + value, + (invalidPathPart) => { + invalidPath = invalidPathPart; + }, + "", + ) + ) { + throw new errors.InvalidStateType({ + path: path + (invalidPath ? `.${invalidPath}` : ""), + }); + } + } + + this.#actor.rLog.debug({ + msg: "onChange triggered, setting persistChanged=true", + path, + }); + this.#persistChanged = true; + + // Inform inspector about state changes + if (actorStatePath) { + this.#actor.inspector.emitter.emit( + "stateUpdated", + this.#persist.state, + ); + } + + // Call onStateChange lifecycle hook + if ( + actorStatePath && + this.#config.onStateChange && + this.#actor.isReady() && + !this.#isInOnStateChange + ) { + try { + this.#isInOnStateChange = true; + this.#config.onStateChange( + this.#actor.actorContext, + this.#persistRaw.state, + ); + } catch (error) { + this.#actor.rLog.error({ + msg: "error in `_onStateChange`", + error: stringifyError(error), + }); + } finally { + this.#isInOnStateChange = false; + } + } + } + + async #savePersistInner() { + try { + this.#lastSaveTime = Date.now(); + + if (this.#persistChanged) { + await this.#persistWriteQueue.enqueue(async () => { + this.#actor.rLog.debug({ + msg: "saving persist", + actorChanged: this.#persistChanged, + }); + + const entry = this.getPersistedDataIfChanged(); + if (entry) { + await this.#actorDriver.kvBatchPut(this.#actor.id, [ + entry, + ]); + } + + this.#actor.rLog.debug({ msg: "persist saved" }); + }); + } + + this.#onPersistSavedPromise?.resolve(); + } catch (error) { + this.#actor.rLog.error({ + msg: "error saving persist", + error: stringifyError(error), + }); + this.#onPersistSavedPromise?.reject(error); + throw error; + } + } + + async #writePersistedDataDirect(persistData: PersistedActor) { + const bareData = this.convertToBarePersisted(persistData); + await this.#actorDriver.kvBatchPut(this.#actor.id, [ + [ + KEYS.PERSIST_DATA, + ACTOR_VERSIONED.serializeWithEmbeddedVersion(bareData), + ], + ]); + } +} diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/mod.ts b/rivetkit-typescript/packages/rivetkit/src/actor/mod.ts index 4f934113a1..4a8c010b74 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/mod.ts @@ -72,10 +72,10 @@ export type { UniversalWebSocket, } from "@/common/websocket-interface"; export type { ActorKey } from "@/manager/protocol/query"; -export type { ActionContext } from "./action"; export type * from "./config"; -export type { Conn } from "./conn"; -export type { ActorContext } from "./context"; +export type { Conn } from "./conn/mod"; +export type { ActionContext } from "./contexts/action"; +export type { ActorContext } from "./contexts/actor"; export type { ActionContextOf, ActorContextOf, @@ -84,7 +84,7 @@ export type { } from "./definition"; export { lookupInRegistry } from "./definition"; export { UserError, type UserErrorOptions } from "./errors"; -export type { AnyActorInstance } from "./instance"; +export type { AnyActorInstance } from "./instance/mod"; export { type ActorRouter, createActorRouter, diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/protocol/old.ts b/rivetkit-typescript/packages/rivetkit/src/actor/protocol/old.ts index 2bb281e8cb..0cd53e94cb 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/protocol/old.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/protocol/old.ts @@ -15,9 +15,9 @@ import { } from "@/schemas/client-protocol/versioned"; import { deserializeWithEncoding } from "@/serde"; import { assertUnreachable, bufferToArrayBuffer } from "../../utils"; -import { ActionContext } from "../action"; -import type { Conn } from "../conn"; -import type { ActorInstance } from "../instance"; +import type { Conn } from "../conn/mod"; +import { ActionContext } from "../contexts/action"; +import type { ActorInstance } from "../instance/mod"; interface MessageEventOpts { encoding: Encoding; @@ -160,7 +160,7 @@ export async function processMessage< }); // Send the response back to the client - conn._sendMessage( + conn.sendMessage( new CachedSerializer( { body: { @@ -229,7 +229,7 @@ export async function processMessage< }); // Build response - conn._sendMessage( + conn.sendMessage( new CachedSerializer( { body: { diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts b/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts index b9d1bd8d4f..b955801972 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts @@ -1,11 +1,13 @@ import * as cbor from "cbor-x"; import type { Context as HonoContext, HonoRequest } from "hono"; import type { WSContext } from "hono/ws"; -import { ActionContext } from "@/actor/action"; -import { type AnyConn, generateConnRequestId } from "@/actor/conn"; -import { ConnDriverKind } from "@/actor/conn-drivers"; +import type { AnyConn } from "@/actor/conn/mod"; +import { ActionContext } from "@/actor/contexts/action"; import * as errors from "@/actor/errors"; -import { type AnyActorInstance, PERSIST_SYMBOL } from "@/actor/instance"; +import { + ACTOR_INSTANCE_PERSIST_SYMBOL, + type AnyActorInstance, +} from "@/actor/instance/mod"; import type { InputData } from "@/actor/protocol/serde"; import { type Encoding, EncodingSchema } from "@/actor/protocol/serde"; import { @@ -22,7 +24,6 @@ import type * as protocol from "@/schemas/client-protocol/mod"; import { HTTP_ACTION_REQUEST_VERSIONED, HTTP_ACTION_RESPONSE_VERSIONED, - TO_SERVER_VERSIONED, } from "@/schemas/client-protocol/versioned"; import { contentTypeForEncoding, @@ -34,6 +35,8 @@ import { bufferToArrayBuffer, promiseWithResolvers, } from "@/utils"; +import { createHttpSocket } from "./conn/drivers/http"; +import { createWebSocketSocket } from "./conn/drivers/websocket"; import type { ActorDriver } from "./driver"; import { loggerWithoutContext } from "./log"; import { parseMessage } from "./protocol/old"; @@ -136,7 +139,7 @@ export async function handleWebSocketConnect( } // Promise used to wait for the websocket close in `disconnect` - const closePromise = promiseWithResolvers(); + const closePromiseResolvers = promiseWithResolvers(); // Track connection outside of scope for cleanup let createdConn: AnyConn | undefined; @@ -158,27 +161,24 @@ export async function handleWebSocketConnect( // Check if this is a hibernatable websocket const isHibernatable = !!requestIdBuf && - actor[PERSIST_SYMBOL].hibernatableConns.findIndex( - (conn) => - arrayBuffersEqual( - conn.hibernatableRequestId, - requestIdBuf, - ), + actor[ + ACTOR_INSTANCE_PERSIST_SYMBOL + ].hibernatableConns.findIndex((conn) => + arrayBuffersEqual( + conn.hibernatableRequestId, + requestIdBuf, + ), ) !== -1; conn = await actor.createConn( - { - requestId: requestId, - requestIdBuf: requestIdBuf, - hibernatable: isHibernatable, - driverState: { - [ConnDriverKind.WEBSOCKET]: { - encoding, - websocket: ws, - closePromise, - }, - }, - }, + createWebSocketSocket( + requestId, + requestIdBuf, + isHibernatable, + encoding, + ws, + closePromiseResolvers.promise, + ), parameters, req, ); @@ -264,7 +264,7 @@ export async function handleWebSocketConnect( ) => { handlersReject(`WebSocket closed (${event.code}): ${event.reason}`); - closePromise.resolve(); + closePromiseResolvers.resolve(); if (event.wasClean) { actor.rLog.info({ @@ -290,7 +290,7 @@ export async function handleWebSocketConnect( handlersPromise.finally(() => { if (createdConn) { const wasClean = event.wasClean || event.code === 1000; - actor.__connDisconnected(createdConn, wasClean, requestId); + actor.connDisconnected(createdConn, wasClean); } }); }, @@ -331,7 +331,6 @@ export async function handleAction( HTTP_ACTION_REQUEST_VERSIONED, ); const actionArgs = cbor.decode(new Uint8Array(request.args)); - const requestId = generateConnRequestId(); // Invoke the action let actor: AnyActorInstance | undefined; @@ -344,11 +343,7 @@ export async function handleAction( // Create conn conn = await actor.createConn( - { - requestId: requestId, - hibernatable: false, - driverState: { [ConnDriverKind.HTTP]: {} }, - }, + createHttpSocket(), parameters, c.req.raw, ); @@ -359,7 +354,7 @@ export async function handleAction( } finally { if (conn) { // HTTP connections don't have persistent sockets, so no socket ID needed - actor?.__connDisconnected(conn, true, requestId); + actor?.connDisconnected(conn, true); } } @@ -394,7 +389,9 @@ export async function handleRawWebSocketHandler( // Extract rivetRequestId provided by engine runner const rivetRequestId = evt?.rivetRequestId; const isHibernatable = - actor[PERSIST_SYMBOL].hibernatableConns.findIndex((conn) => + actor[ + ACTOR_INSTANCE_PERSIST_SYMBOL + ].hibernatableConns.findIndex((conn) => arrayBuffersEqual( conn.hibernatableRequestId, rivetRequestId, diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/router.ts b/rivetkit-typescript/packages/rivetkit/src/actor/router.ts index c347e98cfe..305ce03a38 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/router.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/router.ts @@ -29,8 +29,7 @@ import { } from "@/inspector/actor"; import { isInspectorEnabled, secureInspector } from "@/inspector/utils"; import type { RunnerConfig } from "@/registry/run-config"; -import { generateConnRequestId } from "./conn"; -import { ConnDriverKind } from "./conn-drivers"; +import { CONN_DRIVER_SYMBOL, generateConnRequestId } from "./conn/mod"; import type { ActorDriver } from "./driver"; import { InternalError } from "./errors"; import { loggerWithoutContext } from "./log"; @@ -66,11 +65,11 @@ export function createActorRouter( // Track all HTTP requests to prevent actor from sleeping during active requests router.use("*", async (c, next) => { const actor = await actorDriver.loadActor(c.env.actorId); - actor.__beginHonoHttpRequest(); + actor.beginHonoHttpRequest(); try { await next(); } finally { - actor.__endHonoHttpRequest(); + actor.endHonoHttpRequest(); } }); @@ -94,19 +93,15 @@ export function createActorRouter( } const actor = await actorDriver.loadActor(c.env.actorId); - const conn = actor.__getConnForId(connId); + const conn = actor.getConnForId(connId); if (!conn) { return c.text(`Connection not found: ${connId}`, 404); } // Force close the connection without clean shutdown - const driverState = conn.__driverState; - if (driverState && ConnDriverKind.WEBSOCKET in driverState) { - const ws = driverState[ConnDriverKind.WEBSOCKET].websocket; - - // Force close without sending close frame - (ws.raw as any).terminate(); + if (conn[CONN_DRIVER_SYMBOL]?.terminate) { + conn[CONN_DRIVER_SYMBOL].terminate(actor, conn); } return c.json({ success: true }); diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/schedule.ts b/rivetkit-typescript/packages/rivetkit/src/actor/schedule.ts index 8948e8ccff..b512208b88 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/schedule.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/schedule.ts @@ -1,4 +1,4 @@ -import type { AnyActorInstance } from "./instance"; +import type { AnyActorInstance } from "./instance/mod"; export class Schedule { #actor: AnyActorInstance; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/unstable-react.ts b/rivetkit-typescript/packages/rivetkit/src/actor/unstable-react.ts deleted file mode 100644 index f8ffaf65d4..0000000000 --- a/rivetkit-typescript/packages/rivetkit/src/actor/unstable-react.ts +++ /dev/null @@ -1,110 +0,0 @@ -//// @ts-expect-error we do not have types for this lib -//import { renderToPipeableStream } from "@jogit/tmp-react-server-dom-nodeless"; -//import getStream from "get-stream"; -//import { isValidElement } from "react"; -//import { Actor } from "./actor"; -// -///** -// * A React Server Components (RSC) actor. -// * -// * Supports rendering React elements as action responses. -// * -// * @see [Documentation](https://rivet.dev/docs/client/react) -// * @experimental -// */ -//export class RscActor< -// State = undefined, -// ConnParams = undefined, -// ConnState = undefined, -//> extends Actor { -// /** -// * Updates the RSCs for all connected clients. -// */ -// public _updateRsc() { -// // Broadcast a message to all connected clients, telling them to re-render -// this._broadcast("__rsc"); -// } -// -// /** -// * Overrides default behavior to update the RSCs when the state changes. -// * @private -// * @internal -// */ -// override _onStateChange() { -// this._updateRsc(); -// } -// -// /** -// * Overrides default behavior to render React elements as RSC response. -// * @private -// * @internal -// */ -// protected override _onBeforeActionResponse( -// _name: string, -// _args: unknown[], -// output: Out, -// ): Out { -// if (!isValidElement(output)) { -// return super._onBeforeActionResponse(_name, _args, output); -// } -// -// // The output is a React element, so we need to transform it into a valid rsc message -// const { readable, ...writable } = nodeStreamToWebStream(); -// -// const stream = renderToPipeableStream(output); -// -// stream.pipe(writable); -// -// return getStream(readable) as Out; -// } -//} -// -//function nodeStreamToWebStream() { -// const buffer: Uint8Array[] = []; -// let writer: WritableStreamDefaultWriter | null = null; -// -// const writable = new WritableStream({ -// write(chunk) { -// buffer.push(chunk); -// }, -// close() {}, -// }); -// -// const readable = new ReadableStream({ -// start() {}, -// async pull(controller) { -// if (buffer.length > 0) { -// const chunk = buffer.shift(); // Get the next chunk from the buffer -// if (chunk) { -// controller.enqueue(chunk); // Push it to the readable stream -// } -// } else { -// if (writable.locked) { -// await new Promise((resolve) => setTimeout(resolve, 10)); -// return this.pull?.(controller); -// } -// return controller.close(); -// } -// }, -// cancel() {}, -// }); -// -// return { -// readable, -// on: (str: string, fn: () => void) => { -// if (str === "drain") { -// writer = writable.getWriter(); -// fn(); -// } -// }, -// write(chunk: Uint8Array) { -// writer?.write(chunk); -// }, -// flush() { -// writer?.close(); -// }, -// end() { -// writer?.releaseLock(); -// }, -// }; -//} diff --git a/rivetkit-typescript/packages/rivetkit/src/driver-helpers/mod.ts b/rivetkit-typescript/packages/rivetkit/src/driver-helpers/mod.ts index d6622c2717..7b36b9bc4f 100644 --- a/rivetkit-typescript/packages/rivetkit/src/driver-helpers/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/driver-helpers/mod.ts @@ -1,5 +1,5 @@ export type { ActorDriver } from "@/actor/driver"; -export type { ActorInstance, AnyActorInstance } from "@/actor/instance"; +export type { ActorInstance, AnyActorInstance } from "@/actor/instance/mod"; export { generateRandomString } from "@/actor/utils"; export { ALLOWED_PUBLIC_HEADERS, diff --git a/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts b/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts index 9bfdc507fb..74b3544edc 100644 --- a/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts +++ b/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts @@ -10,9 +10,9 @@ import { streamSSE } from "hono/streaming"; import { WSContext } from "hono/ws"; import invariant from "invariant"; import { lookupInRegistry } from "@/actor/definition"; -import { PERSIST_SYMBOL } from "@/actor/instance"; +import { KEYS } from "@/actor/instance/kv"; +import { ACTOR_INSTANCE_PERSIST_SYMBOL } from "@/actor/instance/mod"; import { deserializeActorKey } from "@/actor/keys"; -import { KEYS } from "@/actor/kv"; import { EncodingSchema } from "@/actor/protocol/serde"; import { type ActorRouter, createActorRouter } from "@/actor/router"; import { @@ -170,7 +170,8 @@ export class EngineActorDriver implements ActorDriver { // Check for existing WS const hibernatableArray = - handler.actor[PERSIST_SYMBOL].hibernatableConns; + handler.actor[ACTOR_INSTANCE_PERSIST_SYMBOL] + .hibernatableConns; logger().debug({ msg: "checking hibernatable websockets", requestId: idToStr(requestId), @@ -347,7 +348,7 @@ export class EngineActorDriver implements ActorDriver { // Set alarm const delay = Math.max(0, timestamp - Date.now()); this.#alarmTimeout = setLongTimeout(() => { - actor._onAlarm(); + actor.onAlarm(); this.#alarmTimeout = undefined; }, delay); @@ -358,7 +359,7 @@ export class EngineActorDriver implements ActorDriver { // Instead, it just wakes the actor on the alarm (if not // already awake). // - // _onAlarm is automatically called on `ActorInstance.start` when waking + // onAlarm is automatically called on `ActorInstance.start` when waking // again. this.#runner.setAlarm(actor.id, timestamp); } @@ -486,10 +487,10 @@ export class EngineActorDriver implements ActorDriver { const handler = this.#actors.get(actorId); if (handler?.actor) { try { - await handler.actor._onStop(); + await handler.actor.onStop(); } catch (err) { logger().error({ - msg: "error in _onStop, proceeding with removing actor", + msg: "error in onStop, proceeding with removing actor", err: stringifyError(err), }); } @@ -663,14 +664,14 @@ export class EngineActorDriver implements ActorDriver { // reschedule // // This means that: - // - All actors on this runner are bricked until the slowest _onStop finishes + // - All actors on this runner are bricked until the slowest onStop finishes // - Guard will not gracefully handle requests bc it's not receiving a 503 // - Actors can still be scheduled to this runner while the other - // actors are stopping, meaning that those actors will NOT get _onStop + // actors are stopping, meaning that those actors will NOT get onStop // and will potentiall corrupt their state // // HACK: Stop all actors to allow state to be saved - // NOTE: _onStop is only supposed to be called by the runner, we're + // NOTE: onStop is only supposed to be called by the runner, we're // abusing it here logger().debug({ msg: "stopping all actors before shutdown", @@ -680,9 +681,9 @@ export class EngineActorDriver implements ActorDriver { for (const [_actorId, handler] of this.#actors.entries()) { if (handler.actor) { stopPromises.push( - handler.actor._onStop().catch((err) => { + handler.actor.onStop().catch((err) => { handler.actor?.rLog.error({ - msg: "_onStop errored", + msg: "onStop errored", error: stringifyError(err), }); }), diff --git a/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/global-state.ts b/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/global-state.ts index 016d3a489a..9bc76f87a8 100644 --- a/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/global-state.ts +++ b/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/global-state.ts @@ -5,7 +5,7 @@ import * as path from "node:path"; import invariant from "invariant"; import { lookupInRegistry } from "@/actor/definition"; import { ActorAlreadyExists } from "@/actor/errors"; -import type { AnyActorInstance } from "@/actor/instance"; +import type { AnyActorInstance } from "@/actor/instance/mod"; import type { ActorKey } from "@/actor/mod"; import { generateRandomString } from "@/actor/utils"; import type { AnyClient } from "@/client/client"; @@ -341,7 +341,7 @@ export class FileSystemGlobalState { // Stop actor invariant(actor.actor, "actor should be loaded"); - await actor.actor._onStop(); + await actor.actor.onStop(); // Remove from map after stop is complete this.#actors.delete(actorId); @@ -672,7 +672,7 @@ export class FileSystemGlobalState { } invariant(loaded.actor, "actor should be loaded after wake"); - await loaded.actor._onAlarm(); + await loaded.actor.onAlarm(); } catch (err) { logger().error({ msg: "failed to handle alarm", diff --git a/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/manager.ts b/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/manager.ts index f5ab47c125..adf5691d71 100644 --- a/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/manager.ts +++ b/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/manager.ts @@ -1,6 +1,6 @@ import type { Context as HonoContext } from "hono"; import invariant from "invariant"; -import { generateConnRequestId } from "@/actor/conn"; +import { generateConnRequestId } from "@/actor/conn/mod"; import { type ActorRouter, createActorRouter } from "@/actor/router"; import { handleRawWebSocketHandler, diff --git a/rivetkit-typescript/packages/rivetkit/tests/actor-types.test.ts b/rivetkit-typescript/packages/rivetkit/tests/actor-types.test.ts index 74a58a7983..e6f16d7944 100644 --- a/rivetkit-typescript/packages/rivetkit/tests/actor-types.test.ts +++ b/rivetkit-typescript/packages/rivetkit/tests/actor-types.test.ts @@ -1,5 +1,5 @@ import { describe, expectTypeOf, it } from "vitest"; -import type { ActorContext } from "@/actor/context"; +import type { ActorContext } from "@/actor/contexts/actor"; import type { ActorContextOf, ActorDefinition } from "@/actor/definition"; describe("ActorDefinition", () => {