diff --git a/rivetkit-typescript/packages/cloudflare-workers/src/actor-driver.ts b/rivetkit-typescript/packages/cloudflare-workers/src/actor-driver.ts index dbb09e359b..1bace9e7af 100644 --- a/rivetkit-typescript/packages/cloudflare-workers/src/actor-driver.ts +++ b/rivetkit-typescript/packages/cloudflare-workers/src/actor-driver.ts @@ -6,10 +6,10 @@ import type { } from "rivetkit"; import { lookupInRegistry } from "rivetkit"; import type { Client } from "rivetkit/client"; -import { - type ActorDriver, - type AnyActorInstance, - type ManagerDriver, +import type { + ActorDriver, + AnyActorInstance, + ManagerDriver, } from "rivetkit/driver-helpers"; import { promiseWithResolvers } from "rivetkit/utils"; import { KEYS } from "./actor-handler-do"; @@ -239,7 +239,6 @@ export class CloudflareActorsActorDriver implements ActorDriver { // Persist data key return Uint8Array.from([1]); } - } export function createCloudflareActorsActorDriverBuilder( diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts index c4fdf59eee..25d8395722 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts @@ -1,11 +1,8 @@ import * as cbor from "cbor-x"; -import onChange from "on-change"; -import { isCborSerializable } from "@/common/utils"; import type * as protocol from "@/schemas/client-protocol/mod"; import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned"; import { arrayBuffersEqual, bufferToArrayBuffer } from "@/utils"; import type { AnyDatabaseProvider } from "../database"; -import * as errors from "../errors"; import { ACTOR_INSTANCE_PERSIST_SYMBOL, type ActorInstance, @@ -13,6 +10,7 @@ import { import type { PersistedConn } from "../instance/persisted"; import { CachedSerializer } from "../protocol/serde"; import type { ConnDriver } from "./driver"; +import { StateManager } from "./state-manager"; export function generateConnRequestId(): string { return crypto.randomUUID(); @@ -24,6 +22,12 @@ export type AnyConn = Conn; export const CONN_PERSIST_SYMBOL = Symbol("persist"); export const CONN_DRIVER_SYMBOL = Symbol("driver"); +export const CONN_ACTOR_SYMBOL = Symbol("actor"); +export const CONN_STATE_ENABLED_SYMBOL = Symbol("stateEnabled"); +export const CONN_PERSIST_RAW_SYMBOL = Symbol("persistRaw"); +export const CONN_HAS_CHANGES_SYMBOL = Symbol("hasChanges"); +export const CONN_MARK_SAVED_SYMBOL = Symbol("markSaved"); +export const CONN_SEND_MESSAGE_SYMBOL = Symbol("sendMessage"); /** * Represents a client connection to a actor. @@ -38,34 +42,30 @@ export class Conn { // TODO: Remove this cyclical reference #actor: ActorInstance; - /** - * The proxied state that notifies of changes automatically. - * - * Any data that should be stored indefinitely should be held within this - * object. - * - * This will only be persisted if using hibernatable WebSockets. If not, - * this is just used to hole state. - */ - [CONN_PERSIST_SYMBOL]!: PersistedConn; - - /** Raw persist object without the proxy wrapper */ - #persistRaw: PersistedConn; - - /** Track if this connection's state has changed */ - #changed = false; + // MARK: - Managers + #stateManager!: StateManager; /** * If undefined, then nothing is connected to this. */ [CONN_DRIVER_SYMBOL]?: ConnDriver; - public get params(): CP { - return this[CONN_PERSIST_SYMBOL].params; + // MARK: - Public Getters + + get [CONN_ACTOR_SYMBOL](): ActorInstance { + return this.#actor; } - public get stateEnabled() { - return this.#actor.connStateEnabled; + get [CONN_PERSIST_SYMBOL](): PersistedConn { + return this.#stateManager.persist; + } + + get params(): CP { + return this.#stateManager.params; + } + + get [CONN_STATE_ENABLED_SYMBOL](): boolean { + return this.#stateManager.stateEnabled; } /** @@ -73,11 +73,8 @@ export class Conn { * * Throws an error if the state is not enabled. */ - public get state(): CS { - this.#validateStateEnabled(); - if (!this[CONN_PERSIST_SYMBOL].state) - throw new Error("state should exists"); - return this[CONN_PERSIST_SYMBOL].state; + get state(): CS { + return this.#stateManager.state; } /** @@ -85,16 +82,15 @@ export class Conn { * * Throws an error if the state is not enabled. */ - public set state(value: CS) { - this.#validateStateEnabled(); - this[CONN_PERSIST_SYMBOL].state = value; + set state(value: CS) { + this.#stateManager.state = value; } /** * Unique identifier for the connection. */ - public get id(): ConnId { - return this[CONN_PERSIST_SYMBOL].connId; + get id(): ConnId { + return this.#stateManager.persist.connId; } /** @@ -102,8 +98,10 @@ export class Conn { * * If the underlying connection can hibernate. */ - public get isHibernatable(): boolean { - if (!this[CONN_PERSIST_SYMBOL].hibernatableRequestId) { + get isHibernatable(): boolean { + const hibernatableRequestId = + this.#stateManager.persist.hibernatableRequestId; + if (!hibernatableRequestId) { return false; } return ( @@ -112,7 +110,7 @@ export class Conn { ].hibernatableConns.findIndex((conn: any) => arrayBuffersEqual( conn.hibernatableRequestId, - this[CONN_PERSIST_SYMBOL].hibernatableRequestId!, + hibernatableRequestId, ), ) > -1 ); @@ -121,8 +119,8 @@ export class Conn { /** * Timestamp of the last time the connection was seen, i.e. the last time the connection was active and checked for liveness. */ - public get lastSeen(): number { - return this[CONN_PERSIST_SYMBOL].lastSeen; + get lastSeen(): number { + return this.#stateManager.persist.lastSeen; } /** @@ -132,94 +130,37 @@ export class Conn { * * @protected */ - public constructor( + constructor( actor: ActorInstance, persist: PersistedConn, ) { this.#actor = actor; - this.#persistRaw = persist; - this.#setupPersistProxy(persist); - } - - /** - * Sets up the proxy for connection persistence with change tracking - */ - #setupPersistProxy(persist: PersistedConn) { - // If this can't be proxied, return raw value - if (persist === null || typeof persist !== "object") { - this[CONN_PERSIST_SYMBOL] = persist; - return; - } - - // Listen for changes to the object - this[CONN_PERSIST_SYMBOL] = onChange( - persist, - ( - path: string, - value: any, - _previousValue: any, - _applyData: any, - ) => { - // Validate CBOR serializability for state changes - if (path.startsWith("state")) { - let invalidPath = ""; - if ( - !isCborSerializable( - value, - (invalidPathPart: string) => { - invalidPath = invalidPathPart; - }, - "", - ) - ) { - throw new errors.InvalidStateType({ - path: path + (invalidPath ? `.${invalidPath}` : ""), - }); - } - } - - this.#changed = true; - this.#actor.rLog.debug({ - msg: "conn onChange triggered", - connId: this.id, - path, - }); - - // Notify actor that this connection has changed - this.#actor.markConnChanged(this); - }, - { ignoreDetached: true }, - ); + this.#stateManager = new StateManager(this); + this.#stateManager.initPersistProxy(persist); } /** * Returns whether this connection has unsaved changes */ - get hasChanges(): boolean { - return this.#changed; + [CONN_HAS_CHANGES_SYMBOL](): boolean { + return this.#stateManager.hasChanges(); } /** * Marks changes as saved */ - markSaved() { - this.#changed = false; + [CONN_MARK_SAVED_SYMBOL]() { + this.#stateManager.markSaved(); } /** * Gets the raw persist data for serialization */ - get persistRaw(): PersistedConn { - return this.#persistRaw; - } - - #validateStateEnabled() { - if (!this.stateEnabled) { - throw new errors.ConnStateNotEnabled(); - } + get [CONN_PERSIST_RAW_SYMBOL](): PersistedConn { + return this.#stateManager.persistRaw; } - public sendMessage(message: CachedSerializer) { + [CONN_SEND_MESSAGE_SYMBOL](message: CachedSerializer) { if (this[CONN_DRIVER_SYMBOL]) { const driver = this[CONN_DRIVER_SYMBOL]; if (driver.sendMessage) { @@ -245,14 +186,14 @@ export class Conn { * @param args - The arguments for the event. * @see {@link https://rivet.dev/docs/events|Events Documentation} */ - public send(eventName: string, ...args: unknown[]) { + send(eventName: string, ...args: unknown[]) { this.#actor.inspector.emitter.emit("eventFired", { type: "event", eventName, args, connId: this.id, }); - this.sendMessage( + this[CONN_SEND_MESSAGE_SYMBOL]( new CachedSerializer( { body: { @@ -273,7 +214,7 @@ export class Conn { * * @param reason - The reason for disconnection. */ - public async disconnect(reason?: string) { + async disconnect(reason?: string) { if (this[CONN_DRIVER_SYMBOL]) { const driver = this[CONN_DRIVER_SYMBOL]; if (driver.disconnect) { diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn/state-manager.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn/state-manager.ts new file mode 100644 index 0000000000..c1ef43fbbe --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn/state-manager.ts @@ -0,0 +1,139 @@ +import onChange from "on-change"; +import { isCborSerializable } from "@/common/utils"; +import * as errors from "../errors"; +import type { PersistedConn } from "../instance/persisted"; +import { CONN_ACTOR_SYMBOL, CONN_STATE_ENABLED_SYMBOL, type Conn } from "./mod"; + +/** + * Manages connection state persistence, proxying, and change tracking. + * Handles automatic state change detection for connection-specific state. + */ +export class StateManager { + #conn: Conn; + + // State tracking + #persist!: PersistedConn; + #persistRaw!: PersistedConn; + #changed = false; + + constructor(conn: Conn) { + this.#conn = conn; + } + + // MARK: - Public API + + get persist(): PersistedConn { + return this.#persist; + } + + get persistRaw(): PersistedConn { + return this.#persistRaw; + } + + get changed(): boolean { + return this.#changed; + } + + get stateEnabled(): boolean { + return this.#conn[CONN_ACTOR_SYMBOL].connStateEnabled; + } + + get state(): CS { + this.#validateStateEnabled(); + if (!this.#persist.state) throw new Error("state should exists"); + return this.#persist.state; + } + + set state(value: CS) { + this.#validateStateEnabled(); + this.#persist.state = value; + } + + get params(): CP { + return this.#persist.params; + } + + // MARK: - Initialization + + /** + * Creates proxy for persist object that handles automatic state change detection. + */ + initPersistProxy(target: PersistedConn) { + // Set raw persist object + this.#persistRaw = target; + + // If this can't be proxied, return raw value + if (target === null || typeof target !== "object") { + this.#persist = target; + return; + } + + // Listen for changes to the object + this.#persist = onChange( + target, + ( + path: string, + value: any, + _previousValue: any, + _applyData: any, + ) => { + this.#handleChange(path, value); + }, + { ignoreDetached: true }, + ); + } + + // MARK: - Change Management + + /** + * Returns whether this connection has unsaved changes + */ + hasChanges(): boolean { + return this.#changed; + } + + /** + * Marks changes as saved + */ + markSaved() { + this.#changed = false; + } + + // MARK: - Private Helpers + + #validateStateEnabled() { + if (!this.stateEnabled) { + throw new errors.ConnStateNotEnabled(); + } + } + + #handleChange(path: string, value: any) { + // Validate CBOR serializability for state changes + if (path.startsWith("state")) { + let invalidPath = ""; + if ( + !isCborSerializable( + value, + (invalidPathPart: string) => { + invalidPath = invalidPathPart; + }, + "", + ) + ) { + throw new errors.InvalidStateType({ + path: path + (invalidPath ? `.${invalidPath}` : ""), + }); + } + } + + this.#changed = true; + this.#conn[CONN_ACTOR_SYMBOL].rLog.debug({ + msg: "conn onChange triggered", + connId: this.#conn.id, + path, + }); + + // Notify actor that this connection has changed + this.#conn[CONN_ACTOR_SYMBOL].markConnChanged(this.#conn); + } +} diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/connection-manager.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/connection-manager.ts index 23d6481ed0..b888bd039a 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance/connection-manager.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/connection-manager.ts @@ -4,7 +4,10 @@ import type { OnConnectOptions } from "../config"; import type { ConnDriver } from "../conn/driver"; import { CONN_DRIVER_SYMBOL, + CONN_MARK_SAVED_SYMBOL, + CONN_PERSIST_RAW_SYMBOL, CONN_PERSIST_SYMBOL, + CONN_STATE_ENABLED_SYMBOL, Conn, type ConnId, } from "../conn/mod"; @@ -202,9 +205,9 @@ export class ConnectionManager< for (const connId of this.#changedConnections) { const conn = this.#connections.get(connId); if (conn) { - const connData = cbor.encode(conn.persistRaw); + const connData = cbor.encode(conn[CONN_PERSIST_RAW_SYMBOL]); entries.push([makeConnKey(connId), connData]); - conn.markSaved(); + conn[CONN_MARK_SAVED_SYMBOL](); } } diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts index 944617e002..1f67fd6e1a 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts @@ -2,7 +2,11 @@ import * as cbor from "cbor-x"; import type * as protocol from "@/schemas/client-protocol/mod"; import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned"; import { bufferToArrayBuffer } from "@/utils"; -import { CONN_PERSIST_SYMBOL, type Conn } from "../conn/mod"; +import { + CONN_PERSIST_SYMBOL, + CONN_SEND_MESSAGE_SYMBOL, + type Conn, +} from "../conn/mod"; import type { AnyDatabaseProvider } from "../database"; import { CachedSerializer } from "../protocol/serde"; import type { ActorInstance } from "./mod"; @@ -192,7 +196,7 @@ export class EventManager { let sentCount = 0; for (const connection of subscribers) { try { - connection.sendMessage(toClientSerializer); + connection[CONN_SEND_MESSAGE_SYMBOL](toClientSerializer); sentCount++; } catch (error) { this.#actor.rLog.error({ diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts index d79a4aa93c..e4e45a1c19 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts @@ -14,7 +14,13 @@ import { EXTRA_ERROR_LOG, idToStr } from "@/utils"; import type { ActorConfig, InitContext } from "../config"; import type { ConnDriver } from "../conn/driver"; import { createHttpSocket } from "../conn/drivers/http"; -import { CONN_PERSIST_SYMBOL, type Conn, type ConnId } from "../conn/mod"; +import { + CONN_PERSIST_SYMBOL, + CONN_SEND_MESSAGE_SYMBOL, + CONN_STATE_ENABLED_SYMBOL, + type Conn, + type ConnId, +} from "../conn/mod"; import { ActionContext } from "../contexts/action"; import { ActorContext } from "../contexts/actor"; import type { AnyDatabaseProvider, InferDatabaseClient } from "../database"; @@ -134,10 +140,12 @@ export class ActorInstance { ).map(([id, conn]) => ({ id, params: conn.params as any, - state: conn.stateEnabled ? conn.state : undefined, + state: conn[CONN_STATE_ENABLED_SYMBOL] + ? conn.state + : undefined, subscriptions: conn.subscriptions.size, lastSeen: conn.lastSeen, - stateEnabled: conn.stateEnabled, + stateEnabled: conn[CONN_STATE_ENABLED_SYMBOL], isHibernatable: conn.isHibernatable, hibernatableRequestId: conn[CONN_PERSIST_SYMBOL] .hibernatableRequestId @@ -498,7 +506,7 @@ export class ActorInstance { await this.saveState({ immediate: true }); // Send init message - conn.sendMessage( + conn[CONN_SEND_MESSAGE_SYMBOL]( new CachedSerializer( { body: { diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/protocol/old.ts b/rivetkit-typescript/packages/rivetkit/src/actor/protocol/old.ts index 0cd53e94cb..ec41345e14 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/protocol/old.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/protocol/old.ts @@ -15,7 +15,7 @@ import { } from "@/schemas/client-protocol/versioned"; import { deserializeWithEncoding } from "@/serde"; import { assertUnreachable, bufferToArrayBuffer } from "../../utils"; -import type { Conn } from "../conn/mod"; +import { CONN_SEND_MESSAGE_SYMBOL, type Conn } from "../conn/mod"; import { ActionContext } from "../contexts/action"; import type { ActorInstance } from "../instance/mod"; @@ -160,7 +160,7 @@ export async function processMessage< }); // Send the response back to the client - conn.sendMessage( + conn[CONN_SEND_MESSAGE_SYMBOL]( new CachedSerializer( { body: { @@ -229,7 +229,7 @@ export async function processMessage< }); // Build response - conn.sendMessage( + conn[CONN_SEND_MESSAGE_SYMBOL]( new CachedSerializer( { body: {