From c57be74dde802a3daaeb4c0c76fab83c66776f73 Mon Sep 17 00:00:00 2001 From: Nathan Flurry Date: Sat, 8 Nov 2025 19:07:07 -0800 Subject: [PATCH] chore(rivetkit): move conns to separate persisted kv keys --- .../cloudflare-workers/src/actor-driver.ts | 101 +- .../src/actor-handler-do.ts | 6 + .../packages/rivetkit/package.json | 2 +- .../rivetkit/schemas/actor-persist/v3.bare | 15 +- .../schemas/file-system-driver/v2.bare | 23 + .../packages/rivetkit/src/actor/conn.ts | 89 +- .../packages/rivetkit/src/actor/driver.ts | 23 +- .../packages/rivetkit/src/actor/instance.ts | 1610 +++++++++-------- .../packages/rivetkit/src/actor/kv.ts | 14 + .../packages/rivetkit/src/actor/persisted.ts | 58 +- .../rivetkit/src/actor/router-endpoints.ts | 16 +- .../rivetkit/src/driver-helpers/utils.ts | 8 +- .../src/drivers/engine/actor-driver.ts | 115 +- .../rivetkit/src/drivers/engine/kv.ts | 3 - .../rivetkit/src/drivers/file-system/actor.ts | 32 +- .../src/drivers/file-system/global-state.ts | 170 +- .../src/schemas/actor-persist/versioned.ts | 26 +- .../src/schemas/client-protocol/versioned.ts | 8 - .../src/schemas/file-system-driver/mod.ts | 2 +- .../schemas/file-system-driver/versioned.ts | 61 +- 20 files changed, 1416 insertions(+), 966 deletions(-) create mode 100644 rivetkit-typescript/packages/rivetkit/schemas/file-system-driver/v2.bare create mode 100644 rivetkit-typescript/packages/rivetkit/src/actor/kv.ts delete mode 100644 rivetkit-typescript/packages/rivetkit/src/drivers/engine/kv.ts diff --git a/rivetkit-typescript/packages/cloudflare-workers/src/actor-driver.ts b/rivetkit-typescript/packages/cloudflare-workers/src/actor-driver.ts index f84b330351..1bace9e7af 100644 --- a/rivetkit-typescript/packages/cloudflare-workers/src/actor-driver.ts +++ b/rivetkit-typescript/packages/cloudflare-workers/src/actor-driver.ts @@ -139,14 +139,6 @@ export class CloudflareActorsActorDriver implements ActorDriver { return { state: state.ctx }; } - async readPersistedData(actorId: string): Promise { - return await this.#getDOCtx(actorId).storage.get(KEYS.PERSIST_DATA); - } - - async writePersistedData(actorId: string, data: Uint8Array): Promise { - await this.#getDOCtx(actorId).storage.put(KEYS.PERSIST_DATA, data); - } - async setAlarm(actor: AnyActorInstance, timestamp: number): Promise { await this.#getDOCtx(actor.id).storage.setAlarm(timestamp); } @@ -154,6 +146,99 @@ export class CloudflareActorsActorDriver implements ActorDriver { async getDatabase(actorId: string): Promise { return this.#getDOCtx(actorId).storage.sql; } + + // Batch KV operations - convert between Uint8Array and Cloudflare's string-based API + async kvBatchPut( + actorId: string, + entries: [Uint8Array, Uint8Array][], + ): Promise { + const storage = this.#getDOCtx(actorId).storage; + const encoder = new TextDecoder(); + + // Convert Uint8Array entries to object for Cloudflare batch put + const storageObj: Record = {}; + for (const [key, value] of entries) { + // Convert key from Uint8Array to string + const keyStr = this.#uint8ArrayToKey(key); + storageObj[keyStr] = value; + } + + await storage.put(storageObj); + } + + async kvBatchGet( + actorId: string, + keys: Uint8Array[], + ): Promise<(Uint8Array | null)[]> { + const storage = this.#getDOCtx(actorId).storage; + + // Convert keys to strings + const keyStrs = keys.map((k) => this.#uint8ArrayToKey(k)); + + // Get values from storage + const results = await storage.get(keyStrs); + + // Convert Map results to array in same order as input keys + return keyStrs.map((k) => results.get(k) ?? null); + } + + async kvBatchDelete(actorId: string, keys: Uint8Array[]): Promise { + const storage = this.#getDOCtx(actorId).storage; + + // Convert keys to strings + const keyStrs = keys.map((k) => this.#uint8ArrayToKey(k)); + + await storage.delete(keyStrs); + } + + async kvListPrefix( + actorId: string, + prefix: Uint8Array, + ): Promise<[Uint8Array, Uint8Array][]> { + const storage = this.#getDOCtx(actorId).storage; + + // Convert prefix to string + const prefixStr = this.#uint8ArrayToKey(prefix); + + // List with prefix + const results = await storage.list({ prefix: prefixStr }); + + // Convert Map to array of [key, value] tuples + const entries: [Uint8Array, Uint8Array][] = []; + for (const [key, value] of results) { + entries.push([this.#keyToUint8Array(key), value]); + } + + return entries; + } + + // Helper to convert Uint8Array key to string for Cloudflare storage + #uint8ArrayToKey(key: Uint8Array): string { + // Check if this is a connection key (starts with [2]) + if (key.length > 0 && key[0] === 2) { + // Connection key - extract connId + const connId = new TextDecoder().decode(key.slice(1)); + return `${KEYS.CONN_PREFIX}${connId}`; + } + // Otherwise, treat as persist data key [1] + return KEYS.PERSIST_DATA; + } + + // Helper to convert string key back to Uint8Array + #keyToUint8Array(key: string): Uint8Array { + if (key.startsWith(KEYS.CONN_PREFIX)) { + // Connection key + const connId = key.slice(KEYS.CONN_PREFIX.length); + const encoder = new TextEncoder(); + const connIdBytes = encoder.encode(connId); + const result = new Uint8Array(1 + connIdBytes.length); + result[0] = 2; // Connection prefix + result.set(connIdBytes, 1); + return result; + } + // Persist data key + return Uint8Array.from([1]); + } } export function createCloudflareActorsActorDriverBuilder( diff --git a/rivetkit-typescript/packages/cloudflare-workers/src/actor-handler-do.ts b/rivetkit-typescript/packages/cloudflare-workers/src/actor-handler-do.ts index a62989f970..e9cf43997a 100644 --- a/rivetkit-typescript/packages/cloudflare-workers/src/actor-handler-do.ts +++ b/rivetkit-typescript/packages/cloudflare-workers/src/actor-handler-do.ts @@ -20,8 +20,14 @@ export const KEYS = { NAME: "rivetkit:name", KEY: "rivetkit:key", PERSIST_DATA: "rivetkit:data", + CONN_PREFIX: "rivetkit:conn:", }; +// Helper to create a connection key for Cloudflare +export function makeCloudflareConnKey(connId: string): string { + return `${KEYS.CONN_PREFIX}${connId}`; +} + export interface ActorHandlerInterface extends DurableObject { initialize(req: ActorInitRequest): Promise; } diff --git a/rivetkit-typescript/packages/rivetkit/package.json b/rivetkit-typescript/packages/rivetkit/package.json index 79635eb3c5..d9548edc7e 100644 --- a/rivetkit-typescript/packages/rivetkit/package.json +++ b/rivetkit-typescript/packages/rivetkit/package.json @@ -153,7 +153,7 @@ ], "scripts": { "build": "tsup src/mod.ts src/client/mod.ts src/common/log.ts src/common/websocket.ts src/actor/errors.ts src/topologies/coordinate/mod.ts src/topologies/partition/mod.ts src/utils.ts src/driver-helpers/mod.ts src/driver-test-suite/mod.ts src/test/mod.ts src/inspector/mod.ts", - "build:schema": "./scripts/compile-bare.ts compile schemas/client-protocol/v1.bare -o dist/schemas/client-protocol/v1.ts && ./scripts/compile-bare.ts compile schemas/client-protocol/v2.bare -o dist/schemas/client-protocol/v2.ts && ./scripts/compile-bare.ts compile schemas/file-system-driver/v1.bare -o dist/schemas/file-system-driver/v1.ts && ./scripts/compile-bare.ts compile schemas/actor-persist/v1.bare -o dist/schemas/actor-persist/v1.ts && ./scripts/compile-bare.ts compile schemas/actor-persist/v2.bare -o dist/schemas/actor-persist/v2.ts && ./scripts/compile-bare.ts compile schemas/actor-persist/v3.bare -o dist/schemas/actor-persist/v3.ts", + "build:schema": "./scripts/compile-bare.ts compile schemas/client-protocol/v1.bare -o dist/schemas/client-protocol/v1.ts && ./scripts/compile-bare.ts compile schemas/client-protocol/v2.bare -o dist/schemas/client-protocol/v2.ts && ./scripts/compile-bare.ts compile schemas/file-system-driver/v1.bare -o dist/schemas/file-system-driver/v1.ts && ./scripts/compile-bare.ts compile schemas/file-system-driver/v2.bare -o dist/schemas/file-system-driver/v2.ts && ./scripts/compile-bare.ts compile schemas/actor-persist/v1.bare -o dist/schemas/actor-persist/v1.ts && ./scripts/compile-bare.ts compile schemas/actor-persist/v2.bare -o dist/schemas/actor-persist/v2.ts && ./scripts/compile-bare.ts compile schemas/actor-persist/v3.bare -o dist/schemas/actor-persist/v3.ts", "check-types": "tsc --noEmit", "test": "vitest run", "test:watch": "vitest", diff --git a/rivetkit-typescript/packages/rivetkit/schemas/actor-persist/v3.bare b/rivetkit-typescript/packages/rivetkit/schemas/actor-persist/v3.bare index 1a362c9794..9bbd047387 100644 --- a/rivetkit-typescript/packages/rivetkit/schemas/actor-persist/v3.bare +++ b/rivetkit-typescript/packages/rivetkit/schemas/actor-persist/v3.bare @@ -1,10 +1,15 @@ # MARK: Connection +type Subscription struct { + eventName: str +} + # Connection associated with hibernatable WebSocket that should persist across lifecycles. -type PersistedHibernatableConn struct { +type HibernatableConn struct { # Connection ID generated by RivetKit id: str parameters: data state: data + subscriptions: list # Request ID of the hibernatable WebSocket hibernatableRequestId: data @@ -15,7 +20,7 @@ type PersistedHibernatableConn struct { } # MARK: Schedule Event -type PersistedScheduleEvent struct { +type ScheduleEvent struct { eventId: str timestamp: i64 action: str @@ -23,11 +28,11 @@ type PersistedScheduleEvent struct { } # MARK: Actor -type PersistedActor struct { +type Actor struct { # Input data passed to the actor on initialization input: optional hasInitialized: bool state: data - hibernatableConns: list - scheduledEvents: list + hibernatableConns: list + scheduledEvents: list } diff --git a/rivetkit-typescript/packages/rivetkit/schemas/file-system-driver/v2.bare b/rivetkit-typescript/packages/rivetkit/schemas/file-system-driver/v2.bare new file mode 100644 index 0000000000..73ba10773d --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/schemas/file-system-driver/v2.bare @@ -0,0 +1,23 @@ +# File System Driver Schema (v2) + +# MARK: Actor State +type ActorKvEntry struct { + key: data + value: data +} + +type ActorState struct { + actorId: str + name: str + key: list + # KV storage map for actor and connection data + # Keys are strings (base64 encoded), values are byte arrays + kvStorage: list + createdAt: u64 +} + +# MARK: Actor Alarm +type ActorAlarm struct { + actorId: str + timestamp: uint +} diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn.ts index 9f108c0b2c..3d4170cee9 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/conn.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn.ts @@ -1,4 +1,6 @@ import * as cbor from "cbor-x"; +import onChange from "on-change"; +import { isCborSerializable } from "@/common/utils"; import type * as protocol from "@/schemas/client-protocol/mod"; import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned"; import { arrayBuffersEqual, bufferToArrayBuffer } from "@/utils"; @@ -44,7 +46,13 @@ export class Conn { * This will only be persisted if using hibernatable WebSockets. If not, * this is just used to hole state. */ - __persist: PersistedConn; + __persist!: PersistedConn; + + /** Raw persist object without the proxy wrapper */ + #persistRaw: PersistedConn; + + /** Track if this connection's state has changed */ + #changed = false; get __driverState(): ConnDriverState | undefined { return this.__socket?.driverState; @@ -103,9 +111,9 @@ export class Conn { return false; } return ( - this.#actor[PERSIST_SYMBOL].hibernatableWebSocket.findIndex((x) => + this.#actor[PERSIST_SYMBOL].hibernatableConns.findIndex((conn) => arrayBuffersEqual( - x.requestId, + conn.hibernatableRequestId, this.__persist.hibernatableRequestId!, ), ) > -1 @@ -131,7 +139,80 @@ export class Conn { persist: PersistedConn, ) { this.#actor = actor; - this.__persist = persist; + this.#persistRaw = persist; + this.#setupPersistProxy(persist); + } + + /** + * Sets up the proxy for connection persistence with change tracking + */ + #setupPersistProxy(persist: PersistedConn) { + // If this can't be proxied, return raw value + if (persist === null || typeof persist !== "object") { + this.__persist = persist; + return; + } + + // Listen for changes to the object + this.__persist = onChange( + persist, + ( + path: string, + value: any, + _previousValue: any, + _applyData: any, + ) => { + // Validate CBOR serializability for state changes + if (path.startsWith("state")) { + let invalidPath = ""; + if ( + !isCborSerializable( + value, + (invalidPathPart: string) => { + invalidPath = invalidPathPart; + }, + "", + ) + ) { + throw new errors.InvalidStateType({ + path: path + (invalidPath ? `.${invalidPath}` : ""), + }); + } + } + + this.#changed = true; + this.#actor.rLog.debug({ + msg: "conn onChange triggered", + connId: this.id, + path, + }); + + // Notify actor that this connection has changed + this.#actor.__markConnChanged(this); + }, + { ignoreDetached: true }, + ); + } + + /** + * Returns whether this connection has unsaved changes + */ + get hasChanges(): boolean { + return this.#changed; + } + + /** + * Marks changes as saved + */ + markSaved() { + this.#changed = false; + } + + /** + * Gets the raw persist data for serialization + */ + get persistRaw(): PersistedConn { + return this.#persistRaw; } #validateStateEnabled() { diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/driver.ts b/rivetkit-typescript/packages/rivetkit/src/actor/driver.ts index 94b8aafe7f..bdea8d7abc 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/driver.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/driver.ts @@ -19,10 +19,27 @@ export interface ActorDriver { getContext(actorId: string): unknown; - readPersistedData(actorId: string): Promise; + // Batch KV operations + /** Batch write multiple key-value pairs. Keys and values are Uint8Arrays. */ + kvBatchPut( + actorId: string, + entries: [Uint8Array, Uint8Array][], + ): Promise; - /** ActorInstance ensure that only one instance of writePersistedData is called in parallel at a time. */ - writePersistedData(actorId: string, data: Uint8Array): Promise; + /** Batch read multiple keys. Returns null for keys that don't exist. */ + kvBatchGet( + actorId: string, + keys: Uint8Array[], + ): Promise<(Uint8Array | null)[]>; + + /** Batch delete multiple keys. */ + kvBatchDelete(actorId: string, keys: Uint8Array[]): Promise; + + /** List all keys with a given prefix. */ + kvListPrefix( + actorId: string, + prefix: Uint8Array, + ): Promise<[Uint8Array, Uint8Array][]>; // Schedule /** ActorInstance ensure that only one instance of setAlarm is called in parallel at a time. */ diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance.ts index 12dd4ff854..685a73f33e 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance.ts @@ -1,24 +1,21 @@ import * as cbor from "cbor-x"; -import type { SSEStreamingApi } from "hono/streaming"; -import type { WSContext } from "hono/ws"; import invariant from "invariant"; import onChange from "on-change"; -import type { ActorKey, Encoding } from "@/actor/mod"; +import type { ActorKey } from "@/actor/mod"; import type { Client } from "@/client/client"; import { getBaseLogger, getIncludeTarget, type Logger } from "@/common/log"; import { isCborSerializable, stringifyError } from "@/common/utils"; import type { UniversalWebSocket } from "@/common/websocket-interface"; import { ActorInspector } from "@/inspector/actor"; import type { Registry } from "@/mod"; -import type * as bareSchema from "@/schemas/actor-persist/mod"; -import { PERSISTED_ACTOR_VERSIONED } from "@/schemas/actor-persist/versioned"; +import type * as persistSchema from "@/schemas/actor-persist/mod"; +import { ACTOR_VERSIONED } from "@/schemas/actor-persist/versioned"; import type * as protocol from "@/schemas/client-protocol/mod"; import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned"; import { arrayBuffersEqual, bufferToArrayBuffer, EXTRA_ERROR_LOG, - getEnvUniversal, idToStr, promiseWithResolvers, SinglePromiseQueue, @@ -28,9 +25,7 @@ import type { ActorConfig, OnConnectOptions } from "./config"; import { Conn, type ConnId, generateConnRequestId } from "./conn"; import { CONN_DRIVERS, - type ConnDriver, ConnDriverKind, - type ConnDriverState, getConnDriverKindFromState, } from "./conn-drivers"; import type { ConnSocket } from "./conn-socket"; @@ -39,10 +34,11 @@ import type { AnyDatabaseProvider, InferDatabaseClient } from "./database"; import type { ActorDriver } from "./driver"; import * as errors from "./errors"; import { serializeActorKey } from "./keys"; +import { KEYS, makeConnKey } from "./kv"; import type { PersistedActor, PersistedConn, - PersistedHibernatableWebSocket, + PersistedHibernatableConn, PersistedScheduleEvent, } from "./persisted"; import { processMessage } from "./protocol/old"; @@ -146,9 +142,19 @@ export class ActorInstance { /** Actor log, intended for the user to call */ #log!: Logger; + get log(): Logger { + invariant(this.#log, "log not configured"); + return this.#log; + } + /** Runtime log, intended for internal actor logs */ #rLog!: Logger; + get rLog(): Logger { + invariant(this.#rLog, "log not configured"); + return this.#rLog; + } + #sleepCalled = false; #stopCalled = false; @@ -182,18 +188,42 @@ export class ActorInstance { #vars?: V; #backgroundPromises: Promise[] = []; + #abortController = new AbortController(); + #config: ActorConfig; #actorDriver!: ActorDriver; #inlineClient!: Client>; #actorId!: string; + #name!: string; + + get name(): string { + return this.#name; + } + #key!: ActorKey; + + get key(): ActorKey { + return this.#key; + } + #region!: string; + + get region(): string { + return this.#region; + } + #ready = false; #connections = new Map>(); + + get conns(): Map> { + return this.#connections; + } + #subscriptionIndex = new Map>>(); + #changedConnections = new Set(); #sleepTimeout?: NodeJS.Timeout; @@ -205,8 +235,24 @@ export class ActorInstance { #activeRawWebSockets = new Set(); #schedule!: Schedule; + + get schedule(): Schedule { + return this.#schedule; + } + #db!: InferDatabaseClient; + /** + * Gets the database. + * @experimental + */ + get db(): InferDatabaseClient { + if (!this.#db) { + throw new errors.DatabaseNotEnabled(); + } + return this.#db; + } + #inspector = new ActorInspector(() => { return { isDbEnabled: async () => { @@ -311,6 +357,7 @@ export class ActorInstance { this.actorContext = new ActorContext(this); } + // MARK: Initialization async start( actorDriver: ActorDriver, inlineClient: Client>, @@ -349,10 +396,118 @@ export class ActorInstance { this.#region = region; this.#schedule = new Schedule(this); - // Initialize server - // - // Store the promise so network requests can await initialization - await this.#initialize(); + // Read initial state from KV storage + const [persistDataBuffer] = await this.#actorDriver.kvBatchGet( + this.#actorId, + [KEYS.PERSIST_DATA], + ); + invariant( + persistDataBuffer !== null, + "persist data has not been set, it should be set when initialized", + ); + const bareData = + ACTOR_VERSIONED.deserializeWithEmbeddedVersion(persistDataBuffer); + const persistData = this.#convertFromBarePersisted(bareData); + + if (persistData.hasInitialized) { + // List all connection keys + const connEntries = await this.#actorDriver.kvListPrefix( + this.#actorId, + KEYS.CONN_PREFIX, + ); + + // Decode connections + const connections: PersistedConn[] = []; + for (const [_key, value] of connEntries) { + try { + const conn = cbor.decode(value) as PersistedConn; + connections.push(conn); + } catch (error) { + this.#rLog.error({ + msg: "failed to decode connection", + error: stringifyError(error), + }); + } + } + + this.#rLog.info({ + msg: "actor restoring", + connections: connections.length, + hibernatableWebSockets: persistData.hibernatableConns.length, + }); + + // Set initial state + this.#initPersistProxy(persistData); + + // Create connection instances + for (const connPersist of connections) { + // Create connections + const conn = new Conn(this, connPersist); + this.#connections.set(conn.id, conn); + + // Register event subscriptions + for (const sub of connPersist.subscriptions) { + this.#addSubscription(sub.eventName, conn, true); + } + } + } else { + this.#rLog.info({ msg: "actor creating" }); + + // Initialize actor state + let stateData: unknown; + if (this.stateEnabled) { + this.#rLog.info({ msg: "actor state initializing" }); + + if ("createState" in this.#config) { + this.#config.createState; + + // Convert state to undefined since state is not defined yet here + stateData = await this.#config.createState( + this.actorContext as unknown as ActorContext< + undefined, + undefined, + undefined, + undefined, + undefined, + undefined + >, + persistData.input!, + ); + } else if ("state" in this.#config) { + stateData = structuredClone(this.#config.state); + } else { + throw new Error( + "Both 'createState' or 'state' were not defined", + ); + } + } else { + this.#rLog.debug({ msg: "state not enabled" }); + } + + // Save state and mark as initialized + persistData.state = stateData as S; + persistData.hasInitialized = true; + + // Update state + this.#rLog.debug({ msg: "writing state" }); + const bareData = this.#convertToBarePersisted(persistData); + await this.#actorDriver.kvBatchPut(this.#actorId, [ + [ + KEYS.PERSIST_DATA, + ACTOR_VERSIONED.serializeWithEmbeddedVersion(bareData), + ], + ]); + + this.#initPersistProxy(persistData); + + // Notify creation + if (this.#config.onCreate) { + await this.#config.onCreate( + this.actorContext, + persistData.input!, + ); + } + } // TODO: Exit process if this errors if (this.#varsEnabled) { @@ -424,155 +579,275 @@ export class ActorInstance { await this._onAlarm(); } - async #scheduleEventInner(newEvent: PersistedScheduleEvent) { - this.actorContext.log.info({ msg: "scheduling event", ...newEvent }); - - // Insert event in to index - const insertIndex = this.#persist.scheduledEvents.findIndex( - (x) => x.timestamp > newEvent.timestamp, - ); - if (insertIndex === -1) { - this.#persist.scheduledEvents.push(newEvent); - } else { - this.#persist.scheduledEvents.splice(insertIndex, 0, newEvent); - } - - // Update alarm if: - // - this is the newest event (i.e. at beginning of array) or - // - this is the only event (i.e. the only event in the array) - if (insertIndex === 0 || this.#persist.scheduledEvents.length === 1) { - this.actorContext.log.info({ - msg: "setting alarm", - timestamp: newEvent.timestamp, - eventCount: this.#persist.scheduledEvents.length, - }); - await this.#queueSetAlarm(newEvent.timestamp); - } + #assertReady(allowStoppingState: boolean = false) { + if (!this.#ready) throw new errors.InternalError("Actor not ready"); + if (!allowStoppingState && this.#stopCalled) + throw new errors.InternalError("Actor is stopping"); } /** - * Triggers any pending alarms. - * - * This method is idempotent. It's called automatically when the actor wakes - * in order to trigger any pending alarms. + * Check if the actor is ready to handle requests. */ - async _onAlarm() { - const now = Date.now(); - this.actorContext.log.debug({ - msg: "alarm triggered", - now, - events: this.#persist.scheduledEvents.length, - }); - - // Update sleep - // - // Do this before any async logic - this.#resetSleepTimer(); + isReady(): boolean { + return this.#ready; + } - // Remove events from schedule that we're about to run - const runIndex = this.#persist.scheduledEvents.findIndex( - (x) => x.timestamp <= now, - ); - if (runIndex === -1) { - // This method is idempotent, so this will happen in scenarios like `start` and - // no events are pending. - this.#rLog.debug({ msg: "no events are due yet" }); - if (this.#persist.scheduledEvents.length > 0) { - const nextTs = this.#persist.scheduledEvents[0].timestamp; - this.actorContext.log.debug({ - msg: "alarm fired early, rescheduling for next event", - now, - nextTs, - delta: nextTs - now, - }); - await this.#queueSetAlarm(nextTs); - } - this.actorContext.log.debug({ msg: "no events to run", now }); + // MARK: Stop + /** + * For the engine: + * 1. Engine runner receives CommandStopActor + * 2. Engine runner calls _onStop and waits for it to finish + * 3. Engine runner publishes EventActorStateUpdate with ActorStateSTop + */ + async _onStop() { + if (this.#stopCalled) { + this.#rLog.warn({ msg: "already stopping actor" }); return; } - const scheduleEvents = this.#persist.scheduledEvents.splice( - 0, - runIndex + 1, - ); - this.actorContext.log.debug({ - msg: "running events", - count: scheduleEvents.length, - }); - - // Set alarm for next event - if (this.#persist.scheduledEvents.length > 0) { - const nextTs = this.#persist.scheduledEvents[0].timestamp; - this.actorContext.log.info({ - msg: "setting next alarm", - nextTs, - remainingEvents: this.#persist.scheduledEvents.length, - }); - await this.#queueSetAlarm(nextTs); - } + this.#stopCalled = true; - // Iterate by event key in order to ensure we call the events in order - for (const event of scheduleEvents) { - try { - this.actorContext.log.info({ - msg: "running action for event", - event: event.eventId, - timestamp: event.timestamp, - action: event.kind.generic.actionName, - }); + this.#rLog.info({ msg: "actor stopping" }); - // Look up function - const fn: unknown = - this.#config.actions[event.kind.generic.actionName]; + if (this.#sleepTimeout) { + clearTimeout(this.#sleepTimeout); + this.#sleepTimeout = undefined; + } - if (!fn) - throw new Error( - `Missing action for alarm ${event.kind.generic.actionName}`, - ); - if (typeof fn !== "function") - throw new Error( - `Alarm function lookup for ${event.kind.generic.actionName} returned ${typeof fn}`, - ); + // Abort any listeners waiting for shutdown + try { + this.#abortController.abort(); + } catch {} - // Call function - try { - const args = event.kind.generic.args - ? cbor.decode(new Uint8Array(event.kind.generic.args)) - : []; - await fn.call(undefined, this.actorContext, ...args); - } catch (error) { - this.actorContext.log.error({ - msg: "error while running event", + // Call onStop lifecycle hook if defined + if (this.#config.onStop) { + try { + this.#rLog.debug({ msg: "calling onStop" }); + const result = this.#config.onStop(this.actorContext); + if (result instanceof Promise) { + await deadline(result, this.#config.options.onStopTimeout); + } + this.#rLog.debug({ msg: "onStop completed" }); + } catch (error) { + if (error instanceof DeadlineError) { + this.#rLog.error({ msg: "onStop timed out" }); + } else { + this.#rLog.error({ + msg: "error in onStop", error: stringifyError(error), - event: event.eventId, - timestamp: event.timestamp, - action: event.kind.generic.actionName, }); } - } catch (error) { - this.actorContext.log.error({ - msg: "internal error while running event", - error: stringifyError(error), - ...event, + } + } + + const promises: Promise[] = []; + + // Disconnect existing non-hibernatable connections + for (const connection of this.#connections.values()) { + if (!connection.isHibernatable) { + this.#rLog.debug({ + msg: "disconnecting non-hibernatable connection on actor stop", + connId: connection.id, }); + promises.push(connection.disconnect()); } + + // TODO: Figure out how to abort HTTP requests on shutdown. This + // might already be handled by the engine runner tunnel shutdown. } + + // Wait for any background tasks to finish, with timeout + await this.#waitBackgroundPromises( + this.#config.options.waitUntilTimeout, + ); + + // Clear timeouts + if (this.#pendingSaveTimeout) clearTimeout(this.#pendingSaveTimeout); + + // Write state + await this.saveState({ immediate: true, allowStoppingState: true }); + + // Await all `close` event listeners with 1.5 second timeout + const res = Promise.race([ + Promise.all(promises).then(() => false), + new Promise((res) => + globalThis.setTimeout(() => res(true), 1500), + ), + ]); + + if (await res) { + this.#rLog.warn({ + msg: "timed out waiting for connections to close, shutting down anyway", + }); + } + + // Wait for queues to finish + if (this.#persistWriteQueue.runningDrainLoop) + await this.#persistWriteQueue.runningDrainLoop; + if (this.#alarmWriteQueue.runningDrainLoop) + await this.#alarmWriteQueue.runningDrainLoop; } - async scheduleEvent( - timestamp: number, - action: string, - args: unknown[], - ): Promise { - return this.#scheduleEventInner({ - eventId: crypto.randomUUID(), - timestamp, - kind: { - generic: { - actionName: action, - args: bufferToArrayBuffer(cbor.encode(args)), - }, - }, + /** Abort signal that fires when the actor is stopping. */ + get abortSignal(): AbortSignal { + return this.#abortController.signal; + } + + // MARK: Sleep + /** + * Reset timer from the last actor interaction that allows it to be put to sleep. + * + * This should be called any time a sleep-related event happens: + * - Connection opens (will clear timer) + * - Connection closes (will schedule timer if there are no open connections) + * - Alarm triggers (will reset timer) + * + * We don't need to call this on events like individual action calls, since there will always be a connection open for these. + **/ + #resetSleepTimer() { + if (this.#config.options.noSleep || !this.#sleepingSupported) return; + + // Don't sleep if already stopping + if (this.#stopCalled) return; + + const canSleep = this.#canSleep(); + + this.#rLog.debug({ + msg: "resetting sleep timer", + canSleep: CanSleep[canSleep], + existingTimeout: !!this.#sleepTimeout, + timeout: this.#config.options.sleepTimeout, }); + + if (this.#sleepTimeout) { + clearTimeout(this.#sleepTimeout); + this.#sleepTimeout = undefined; + } + + // Don't set a new timer if already sleeping + if (this.#sleepCalled) return; + + if (canSleep === CanSleep.Yes) { + this.#sleepTimeout = setTimeout(() => { + this._startSleep(); + }, this.#config.options.sleepTimeout); + } + } + + /** If this actor can be put in a sleeping state. */ + #canSleep(): CanSleep { + if (!this.#ready) return CanSleep.NotReady; + + // Do not sleep if Hono HTTP requests are in-flight + if (this.#activeHonoHttpRequests > 0) + return CanSleep.ActiveHonoHttpRequests; + + // TODO: When WS hibernation is ready, update this to only count non-hibernatable websockets + // Do not sleep if there are raw websockets open + if (this.#activeRawWebSockets.size > 0) + return CanSleep.ActiveRawWebSockets; + + // Check for active conns. This will also cover active actions, since all actions have a connection. + for (const conn of this.#connections.values()) { + // TODO: Enable this when hibernation is implemented. We're waiting on support for Guard to not auto-wake the actor if it sleeps. + // if (!conn.isHibernatable) + // return false; + + // if (!conn.isHibernatable) return CanSleep.ActiveConns; + return CanSleep.ActiveConns; + } + + return CanSleep.Yes; + } + + /** + * Puts an actor to sleep. This should just start the sleep sequence, most shutdown logic should be in _stop (which is called by the ActorDriver when sleeping). + * + * For the engine, this will: + * 1. Publish EventActorIntent with ActorIntentSleep (via driver.startSleep) + * 2. Engine runner will wait for CommandStopActor + * 3. Engine runner will call _onStop and wait for it to finish + * 4. Engine runner will publish EventActorStateUpdate with ActorStateSTop + **/ + _startSleep() { + if (this.#stopCalled) { + this.#rLog.debug({ + msg: "cannot call _startSleep if actor already stopping", + }); + return; + } + + // IMPORTANT: #sleepCalled should have no effect on the actor's + // behavior aside from preventing calling _startSleep twice. Wait for + // `_onStop` before putting in a stopping state. + if (this.#sleepCalled) { + this.#rLog.warn({ + msg: "cannot call _startSleep twice, actor already sleeping", + }); + return; + } + this.#sleepCalled = true; + + // NOTE: Publishes ActorIntentSleep + const sleep = this.#actorDriver.startSleep?.bind( + this.#actorDriver, + this.#actorId, + ); + invariant(this.#sleepingSupported, "sleeping not supported"); + invariant(sleep, "no sleep on driver"); + + this.#rLog.info({ msg: "actor sleeping" }); + + // Schedule sleep to happen on the next tick. This allows for any action that calls _sleep to complete. + setImmediate(() => { + // The actor driver should call stop when ready to stop + // + // This will call _stop once Pegboard responds with the new status + sleep(); + }); + } + + /** + * Called by router middleware when an HTTP request begins. + */ + __beginHonoHttpRequest() { + this.#activeHonoHttpRequests++; + this.#resetSleepTimer(); + } + + /** + * Called by router middleware when an HTTP request ends. + */ + __endHonoHttpRequest() { + this.#activeHonoHttpRequests--; + if (this.#activeHonoHttpRequests < 0) { + this.#activeHonoHttpRequests = 0; + this.#rLog.warn({ + msg: "active hono requests went below 0, this is a RivetKit bug", + ...EXTRA_ERROR_LOG, + }); + } + this.#resetSleepTimer(); + } + + // MARK: State + /** + * Gets the current state. + * + * Changing properties of this value will automatically be persisted. + */ + get state(): S { + this.#validateStateEnabled(); + return this.#persist.state; + } + + /** + * Sets the current state. + * + * This property will automatically be persisted. + */ + set state(value: S) { + this.#validateStateEnabled(); + this.#persist.state = value; } get stateEnabled() { @@ -589,6 +864,12 @@ export class ActorInstance { return "createConnState" in this.#config || "connState" in this.#config; } + get vars(): V { + this.#validateVarsEnabled(); + invariant(this.#vars !== undefined, "vars not enabled"); + return this.#vars; + } + get #varsEnabled() { return "createVars" in this.#config || "vars" in this.#config; } @@ -599,6 +880,43 @@ export class ActorInstance { } } + /** + * Forces the state to get saved. + * + * This is helpful if running a long task that may fail later or when + * running a background job that updates the state. + * + * @param opts - Options for saving the state. + */ + async saveState(opts: SaveStateOptions) { + this.#assertReady(opts.allowStoppingState); + + this.#rLog.debug({ + msg: "saveState called", + persistChanged: this.#persistChanged, + allowStoppingState: opts.allowStoppingState, + immediate: opts.immediate, + }); + + if (this.#persistChanged) { + if (opts.immediate) { + // Save immediately + await this.#savePersistInner(); + } else { + // Create callback + if (!this.#onPersistSavedPromise) { + this.#onPersistSavedPromise = promiseWithResolvers(); + } + + // Save state throttled + this.#savePersistThrottled(); + + // Wait for save + await this.#onPersistSavedPromise.promise; + } + } + } + /** Promise used to wait for a save to complete. This is required since you cannot await `#saveStateThrottled`. */ #onPersistSavedPromise?: ReturnType>; @@ -627,24 +945,18 @@ export class ActorInstance { try { this.#lastSaveTime = Date.now(); - if (this.#persistChanged) { - const finished = this.#persistWriteQueue.enqueue(async () => { - this.#rLog.debug({ msg: "saving persist" }); + const hasChanges = + this.#persistChanged || this.#changedConnections.size > 0; - // There might be more changes while we're writing, so we set this - // before writing to KV in order to avoid a race condition. - this.#persistChanged = false; + if (hasChanges) { + const finished = this.#persistWriteQueue.enqueue(async () => { + this.#rLog.debug({ + msg: "saving persist", + actorChanged: this.#persistChanged, + connectionsChanged: this.#changedConnections.size, + }); - // Convert to BARE types and write to KV - const bareData = this.#convertToBarePersisted( - this.#persistRaw, - ); - await this.#actorDriver.writePersistedData( - this.#actorId, - PERSISTED_ACTOR_VERSIONED.serializeWithEmbeddedVersion( - bareData, - ), - ); + await this.#writePersistedData(); this.#rLog.debug({ msg: "persist saved" }); }); @@ -661,18 +973,48 @@ export class ActorInstance { this.#onPersistSavedPromise?.reject(error); throw error; } - } + } + + async #writePersistedData() { + const entries: [Uint8Array, Uint8Array][] = []; + + // Save actor state if changed + if (this.#persistChanged) { + this.#persistChanged = false; + + // Prepare actor state + const bareData = this.#convertToBarePersisted(this.#persistRaw); + + // Key [1] for actor persist data + entries.push([ + KEYS.PERSIST_DATA, + ACTOR_VERSIONED.serializeWithEmbeddedVersion(bareData), + ]); + } + + // Save changed connections + if (this.#changedConnections.size > 0) { + for (const connId of this.#changedConnections) { + const conn = this.#connections.get(connId); + if (conn) { + const connData = cbor.encode(conn.persistRaw); + entries.push([makeConnKey(connId), connData]); + conn.markSaved(); + } + } + this.#changedConnections.clear(); + } - async #queueSetAlarm(timestamp: number): Promise { - await this.#alarmWriteQueue.enqueue(async () => { - await this.#actorDriver.setAlarm(this, timestamp); - }); + // Write all entries in batch + if (entries.length > 0) { + await this.#actorDriver.kvBatchPut(this.#actorId, entries); + } } /** * Creates proxy for `#persist` that handles automatically flagging when state needs to be updated. */ - #setPersist(target: PersistedActor) { + #initPersistProxy(target: PersistedActor) { // Set raw persist object this.#persistRaw = target; @@ -775,107 +1117,23 @@ export class ActorInstance { ); } - async #initialize() { - // Read initial state - const persistDataBuffer = await this.#actorDriver.readPersistedData( - this.#actorId, - ); - invariant( - persistDataBuffer !== undefined, - "persist data has not been set, it should be set when initialized", - ); - const bareData = - PERSISTED_ACTOR_VERSIONED.deserializeWithEmbeddedVersion( - persistDataBuffer, - ); - const persistData = this.#convertFromBarePersisted(bareData); - - if (persistData.hasInitiated) { - this.#rLog.info({ - msg: "actor restoring", - connections: persistData.connections.length, - hibernatableWebSockets: - persistData.hibernatableWebSocket.length, - }); - - // Set initial state - this.#setPersist(persistData); - - // Load connections - for (const connPersist of this.#persist.connections) { - // Create connections - const conn = new Conn(this, connPersist); - this.#connections.set(conn.id, conn); - - // Register event subscriptions - for (const sub of connPersist.subscriptions) { - this.#addSubscription(sub.eventName, conn, true); - } - } - } else { - this.#rLog.info({ msg: "actor creating" }); - - // Initialize actor state - let stateData: unknown; - if (this.stateEnabled) { - this.#rLog.info({ msg: "actor state initializing" }); - - if ("createState" in this.#config) { - this.#config.createState; - - // Convert state to undefined since state is not defined yet here - stateData = await this.#config.createState( - this.actorContext as unknown as ActorContext< - undefined, - undefined, - undefined, - undefined, - undefined, - undefined - >, - persistData.input!, - ); - } else if ("state" in this.#config) { - stateData = structuredClone(this.#config.state); - } else { - throw new Error( - "Both 'createState' or 'state' were not defined", - ); - } - } else { - this.#rLog.debug({ msg: "state not enabled" }); - } - - // Save state and mark as initialized - persistData.state = stateData as S; - persistData.hasInitiated = true; - - // Update state - this.#rLog.debug({ msg: "writing state" }); - const bareData = this.#convertToBarePersisted(persistData); - await this.#actorDriver.writePersistedData( - this.#actorId, - PERSISTED_ACTOR_VERSIONED.serializeWithEmbeddedVersion( - bareData, - ), - ); - - this.#setPersist(persistData); - - // Notify creation - if (this.#config.onCreate) { - await this.#config.onCreate( - this.actorContext, - persistData.input!, - ); - } - } - } - + // MARK: Connections __getConnForId(id: string): Conn | undefined { return this.#connections.get(id); } + /** + * Mark a connection as changed so it will be persisted on next save + */ + __markConnChanged(conn: Conn) { + this.#changedConnections.add(conn.id); + this.#rLog.debug({ + msg: "marked connection as changed", + connId: conn.id, + totalChanged: this.#changedConnections.size, + }); + } + /** * Call when conn is disconnected. * @@ -930,22 +1188,26 @@ export class ActorInstance { * Removes a connection and cleans up its resources. */ #removeConn(conn: Conn) { - // Remove from persist & save immediately - const connIdx = this.#persist.connections.findIndex( - (c) => c.connId === conn.id, - ); - if (connIdx !== -1) { - this.#persist.connections.splice(connIdx, 1); - this.saveState({ immediate: true, allowStoppingState: true }); - } else { - this.#rLog.warn({ - msg: "could not find persisted connection to remove", - connId: conn.id, + // Remove conn from KV + const key = makeConnKey(conn.id); + this.#actorDriver + .kvBatchDelete(this.#actorId, [key]) + .then(() => { + this.#rLog.debug({ + msg: "removed connection from KV", + connId: conn.id, + }); + }) + .catch((err) => { + this.#rLog.error({ + msg: "kvBatchDelete failed for conn", + err: stringifyError(err), + }); }); - } - // Remove from state + // Remove from state and tracking this.#connections.delete(conn.id); + this.#changedConnections.delete(conn.id); this.#rLog.debug({ msg: "removed conn", connId: conn.id }); // Remove subscriptions @@ -1119,8 +1381,11 @@ export class ActorInstance { // Check if this connection is for a hibernatable websocket if (socket.requestIdBuf) { const isHibernatable = - this.#persist.hibernatableWebSocket.findIndex((ws) => - arrayBuffersEqual(ws.requestId, socket.requestIdBuf!), + this.#persist.hibernatableConns.findIndex((conn) => + arrayBuffersEqual( + conn.hibernatableRequestId, + socket.requestIdBuf!, + ), ) !== -1; if (isHibernatable) { @@ -1137,8 +1402,9 @@ export class ActorInstance { // Do this immediately after adding connection & before any async logic in order to avoid race conditions with sleep timeouts this.#resetSleepTimer(); - // Add to persistence & save immediately - this.#persist.connections.push(persist); + // Mark connection as changed for batch save + this.#changedConnections.add(conn.id); + this.saveState({ immediate: true }); // Handle connection @@ -1221,97 +1487,7 @@ export class ActorInstance { }); } - // MARK: Events - #addSubscription( - eventName: string, - connection: Conn, - fromPersist: boolean, - ) { - if (connection.subscriptions.has(eventName)) { - this.#rLog.debug({ - msg: "connection already has subscription", - eventName, - }); - return; - } - - // Persist subscriptions & save immediately - // - // Don't update persistence if already restoring from persistence - if (!fromPersist) { - connection.__persist.subscriptions.push({ eventName: eventName }); - this.saveState({ immediate: true }); - } - - // Update subscriptions - connection.subscriptions.add(eventName); - - // Update subscription index - let subscribers = this.#subscriptionIndex.get(eventName); - if (!subscribers) { - subscribers = new Set(); - this.#subscriptionIndex.set(eventName, subscribers); - } - subscribers.add(connection); - } - - #removeSubscription( - eventName: string, - connection: Conn, - fromRemoveConn: boolean, - ) { - if (!connection.subscriptions.has(eventName)) { - this.#rLog.warn({ - msg: "connection does not have subscription", - eventName, - }); - return; - } - - // Persist subscriptions & save immediately - // - // Don't update the connection itself if the connection is already being removed - if (!fromRemoveConn) { - connection.subscriptions.delete(eventName); - - const subIdx = connection.__persist.subscriptions.findIndex( - (s) => s.eventName === eventName, - ); - if (subIdx !== -1) { - connection.__persist.subscriptions.splice(subIdx, 1); - } else { - this.#rLog.warn({ - msg: "subscription does not exist with name", - eventName, - }); - } - - this.saveState({ immediate: true }); - } - - // Update scriptions index - const subscribers = this.#subscriptionIndex.get(eventName); - if (subscribers) { - subscribers.delete(connection); - if (subscribers.size === 0) { - this.#subscriptionIndex.delete(eventName); - } - } - } - - #assertReady(allowStoppingState: boolean = false) { - if (!this.#ready) throw new errors.InternalError("Actor not ready"); - if (!allowStoppingState && this.#stopCalled) - throw new errors.InternalError("Actor is stopping"); - } - - /** - * Check if the actor is ready to handle requests. - */ - isReady(): boolean { - return this.#ready; - } - + // MARK: Actions /** * Execute an action call from a client. * @@ -1514,7 +1690,7 @@ export class ActorInstance { // Track hibernatable WebSockets let rivetRequestId: ArrayBuffer | undefined; let persistedHibernatableWebSocket: - | PersistedHibernatableWebSocket + | PersistedHibernatableConn | undefined; const onSocketOpened = (event: any) => { @@ -1524,16 +1700,16 @@ export class ActorInstance { if (rivetRequestId) { const rivetRequestIdLocal = rivetRequestId; persistedHibernatableWebSocket = - this.#persist.hibernatableWebSocket.find((ws) => + this.#persist.hibernatableConns.find((conn) => arrayBuffersEqual( - ws.requestId, + conn.hibernatableRequestId, rivetRequestIdLocal, ), ); if (persistedHibernatableWebSocket) { persistedHibernatableWebSocket.lastSeenTimestamp = - BigInt(Date.now()); + Date.now(); } } @@ -1549,12 +1725,10 @@ export class ActorInstance { const onSocketMessage = (event: any) => { // Update state of hibernatable WS if (persistedHibernatableWebSocket) { - persistedHibernatableWebSocket.lastSeenTimestamp = BigInt( - Date.now(), - ); - persistedHibernatableWebSocket.msgIndex = BigInt( - event.rivetMessageIndex, - ); + persistedHibernatableWebSocket.lastSeenTimestamp = + Date.now(); + persistedHibernatableWebSocket.msgIndex = + event.rivetMessageIndex; } this.#rLog.debug({ @@ -1570,15 +1744,15 @@ export class ActorInstance { // Remove hibernatable WS if (rivetRequestId) { const rivetRequestIdLocal = rivetRequestId; - const wsIndex = - this.#persist.hibernatableWebSocket.findIndex((ws) => + const wsIndex = this.#persist.hibernatableConns.findIndex( + (conn) => arrayBuffersEqual( - ws.requestId, + conn.hibernatableRequestId, rivetRequestIdLocal, ), - ); + ); - const removed = this.#persist.hibernatableWebSocket.splice( + const removed = this.#persist.hibernatableConns.splice( wsIndex, 1, ); @@ -1604,7 +1778,7 @@ export class ActorInstance { rivetRequestId, isHibernatable: !!persistedHibernatableWebSocket, hibernatableWebSocketCount: - this.#persist.hibernatableWebSocket.length, + this.#persist.hibernatableConns.length, }); // Remove listener and socket from tracking @@ -1643,90 +1817,89 @@ export class ActorInstance { } } - // MARK: Lifecycle hooks + // MARK: Events + #addSubscription( + eventName: string, + connection: Conn, + fromPersist: boolean, + ) { + if (connection.subscriptions.has(eventName)) { + this.#rLog.debug({ + msg: "connection already has subscription", + eventName, + }); + return; + } - // MARK: Exposed methods - get log(): Logger { - invariant(this.#log, "log not configured"); - return this.#log; - } + // Persist subscriptions & save immediately + // + // Don't update persistence if already restoring from persistence + if (!fromPersist) { + connection.__persist.subscriptions.push({ eventName: eventName }); - get rLog(): Logger { - invariant(this.#rLog, "log not configured"); - return this.#rLog; - } + // Mark connection as changed + this.#changedConnections.add(connection.id); - /** - * Gets the name. - */ - get name(): string { - return this.#name; - } + this.saveState({ immediate: true }); + } - /** - * Gets the key. - */ - get key(): ActorKey { - return this.#key; - } + // Update subscriptions + connection.subscriptions.add(eventName); - /** - * Gets the region. - */ - get region(): string { - return this.#region; + // Update subscription index + let subscribers = this.#subscriptionIndex.get(eventName); + if (!subscribers) { + subscribers = new Set(); + this.#subscriptionIndex.set(eventName, subscribers); + } + subscribers.add(connection); } - /** - * Gets the scheduler. - */ - get schedule(): Schedule { - return this.#schedule; - } + #removeSubscription( + eventName: string, + connection: Conn, + fromRemoveConn: boolean, + ) { + if (!connection.subscriptions.has(eventName)) { + this.#rLog.warn({ + msg: "connection does not have subscription", + eventName, + }); + return; + } - /** - * Gets the map of connections. - */ - get conns(): Map> { - return this.#connections; - } + // Persist subscriptions & save immediately + // + // Don't update the connection itself if the connection is already being removed + if (!fromRemoveConn) { + connection.subscriptions.delete(eventName); - /** - * Gets the current state. - * - * Changing properties of this value will automatically be persisted. - */ - get state(): S { - this.#validateStateEnabled(); - return this.#persist.state; - } + const subIdx = connection.__persist.subscriptions.findIndex( + (s) => s.eventName === eventName, + ); + if (subIdx !== -1) { + connection.__persist.subscriptions.splice(subIdx, 1); + } else { + this.#rLog.warn({ + msg: "subscription does not exist with name", + eventName, + }); + } - /** - * Gets the database. - * @experimental - * @throws {DatabaseNotEnabled} If the database is not enabled. - */ - get db(): InferDatabaseClient { - if (!this.#db) { - throw new errors.DatabaseNotEnabled(); - } - return this.#db; - } + // Mark connection as changed + this.#changedConnections.add(connection.id); - /** - * Sets the current state. - * - * This property will automatically be persisted. - */ - set state(value: S) { - this.#validateStateEnabled(); - this.#persist.state = value; - } + this.saveState({ immediate: true }); + } - get vars(): V { - this.#validateVarsEnabled(); - invariant(this.#vars !== undefined, "vars not enabled"); - return this.#vars; + // Update scriptions index + const subscribers = this.#subscriptionIndex.get(eventName); + if (subscribers) { + subscribers.delete(connection); + if (subscribers.size === 0) { + this.#subscriptionIndex.delete(eventName); + } + } } /** @@ -1766,306 +1939,158 @@ export class ActorInstance { } } - /** - * Prevents the actor from sleeping until promise is complete. - * - * This allows the actor runtime to ensure that a promise completes while - * returning from an action request early. - * - * @param promise - The promise to run in the background. - */ - _waitUntil(promise: Promise) { - this.#assertReady(); - - // TODO: Should we force save the state? - // Add logging to promise and make it non-failable - const nonfailablePromise = promise - .then(() => { - this.#rLog.debug({ msg: "wait until promise complete" }); - }) - .catch((error) => { - this.#rLog.error({ - msg: "wait until promise failed", - error: stringifyError(error), - }); - }); - this.#backgroundPromises.push(nonfailablePromise); - } - - /** - * Forces the state to get saved. - * - * This is helpful if running a long task that may fail later or when - * running a background job that updates the state. - * - * @param opts - Options for saving the state. - */ - async saveState(opts: SaveStateOptions) { - this.#assertReady(opts.allowStoppingState); - - this.#rLog.debug({ - msg: "saveState called", - persistChanged: this.#persistChanged, - allowStoppingState: opts.allowStoppingState, - immediate: opts.immediate, - }); - - if (this.#persistChanged) { - if (opts.immediate) { - // Save immediately - await this.#savePersistInner(); - } else { - // Create callback - if (!this.#onPersistSavedPromise) { - this.#onPersistSavedPromise = promiseWithResolvers(); - } - - // Save state throttled - this.#savePersistThrottled(); - - // Wait for save - await this.#onPersistSavedPromise.promise; - } - } - } - - /** - * Called by router middleware when an HTTP request begins. - */ - __beginHonoHttpRequest() { - this.#activeHonoHttpRequests++; - this.#resetSleepTimer(); - } - - /** - * Called by router middleware when an HTTP request ends. - */ - __endHonoHttpRequest() { - this.#activeHonoHttpRequests--; - if (this.#activeHonoHttpRequests < 0) { - this.#activeHonoHttpRequests = 0; - this.#rLog.warn({ - msg: "active hono requests went below 0, this is a RivetKit bug", - ...EXTRA_ERROR_LOG, - }); - } - this.#resetSleepTimer(); - } - - // MARK: Sleep - /** - * Reset timer from the last actor interaction that allows it to be put to sleep. - * - * This should be called any time a sleep-related event happens: - * - Connection opens (will clear timer) - * - Connection closes (will schedule timer if there are no open connections) - * - Alarm triggers (will reset timer) - * - * We don't need to call this on events like individual action calls, since there will always be a connection open for these. - **/ - #resetSleepTimer() { - if (this.#config.options.noSleep || !this.#sleepingSupported) return; - - // Don't sleep if already stopping - if (this.#stopCalled) return; - - const canSleep = this.#canSleep(); - - this.#rLog.debug({ - msg: "resetting sleep timer", - canSleep: CanSleep[canSleep], - existingTimeout: !!this.#sleepTimeout, - timeout: this.#config.options.sleepTimeout, - }); - - if (this.#sleepTimeout) { - clearTimeout(this.#sleepTimeout); - this.#sleepTimeout = undefined; - } - - // Don't set a new timer if already sleeping - if (this.#sleepCalled) return; - - if (canSleep === CanSleep.Yes) { - this.#sleepTimeout = setTimeout(() => { - this._startSleep(); - }, this.#config.options.sleepTimeout); - } - } - - /** If this actor can be put in a sleeping state. */ - #canSleep(): CanSleep { - if (!this.#ready) return CanSleep.NotReady; - - // Do not sleep if Hono HTTP requests are in-flight - if (this.#activeHonoHttpRequests > 0) - return CanSleep.ActiveHonoHttpRequests; - - // TODO: When WS hibernation is ready, update this to only count non-hibernatable websockets - // Do not sleep if there are raw websockets open - if (this.#activeRawWebSockets.size > 0) - return CanSleep.ActiveRawWebSockets; - - // Check for active conns. This will also cover active actions, since all actions have a connection. - for (const conn of this.#connections.values()) { - // TODO: Enable this when hibernation is implemented. We're waiting on support for Guard to not auto-wake the actor if it sleeps. - // if (!conn.isHibernatable) - // return false; - - // if (!conn.isHibernatable) return CanSleep.ActiveConns; - return CanSleep.ActiveConns; - } - - return CanSleep.Yes; - } + // MARK: Alarms + async #scheduleEventInner(newEvent: PersistedScheduleEvent) { + this.actorContext.log.info({ msg: "scheduling event", ...newEvent }); - /** - * Puts an actor to sleep. This should just start the sleep sequence, most shutdown logic should be in _stop (which is called by the ActorDriver when sleeping). - * - * For the engine, this will: - * 1. Publish EventActorIntent with ActorIntentSleep (via driver.startSleep) - * 2. Engine runner will wait for CommandStopActor - * 3. Engine runner will call _onStop and wait for it to finish - * 4. Engine runner will publish EventActorStateUpdate with ActorStateSTop - **/ - _startSleep() { - if (this.#stopCalled) { - this.#rLog.debug({ - msg: "cannot call _startSleep if actor already stopping", - }); - return; + // Insert event in to index + const insertIndex = this.#persist.scheduledEvents.findIndex( + (x) => x.timestamp > newEvent.timestamp, + ); + if (insertIndex === -1) { + this.#persist.scheduledEvents.push(newEvent); + } else { + this.#persist.scheduledEvents.splice(insertIndex, 0, newEvent); } - // IMPORTANT: #sleepCalled should have no effect on the actor's - // behavior aside from preventing calling _startSleep twice. Wait for - // `_onStop` before putting in a stopping state. - if (this.#sleepCalled) { - this.#rLog.warn({ - msg: "cannot call _startSleep twice, actor already sleeping", + // Update alarm if: + // - this is the newest event (i.e. at beginning of array) or + // - this is the only event (i.e. the only event in the array) + if (insertIndex === 0 || this.#persist.scheduledEvents.length === 1) { + this.actorContext.log.info({ + msg: "setting alarm", + timestamp: newEvent.timestamp, + eventCount: this.#persist.scheduledEvents.length, }); - return; + await this.#queueSetAlarm(newEvent.timestamp); } - this.#sleepCalled = true; - - // NOTE: Publishes ActorIntentSleep - const sleep = this.#actorDriver.startSleep?.bind( - this.#actorDriver, - this.#actorId, - ); - invariant(this.#sleepingSupported, "sleeping not supported"); - invariant(sleep, "no sleep on driver"); - - this.#rLog.info({ msg: "actor sleeping" }); + } - // Schedule sleep to happen on the next tick. This allows for any action that calls _sleep to complete. - setImmediate(() => { - // The actor driver should call stop when ready to stop - // - // This will call _stop once Pegboard responds with the new status - sleep(); + async scheduleEvent( + timestamp: number, + action: string, + args: unknown[], + ): Promise { + return this.#scheduleEventInner({ + eventId: crypto.randomUUID(), + timestamp, + action, + args: bufferToArrayBuffer(cbor.encode(args)), }); } - // MARK: Stop /** - * For the engine: - * 1. Engine runner receives CommandStopActor - * 2. Engine runner calls _onStop and waits for it to finish - * 3. Engine runner publishes EventActorStateUpdate with ActorStateSTop + * Triggers any pending alarms. + * + * This method is idempotent. It's called automatically when the actor wakes + * in order to trigger any pending alarms. */ - async _onStop() { - if (this.#stopCalled) { - this.#rLog.warn({ msg: "already stopping actor" }); - return; - } - this.#stopCalled = true; + async _onAlarm() { + const now = Date.now(); + this.actorContext.log.debug({ + msg: "alarm triggered", + now, + events: this.#persist.scheduledEvents.length, + }); - this.#rLog.info({ msg: "actor stopping" }); + // Update sleep + // + // Do this before any async logic + this.#resetSleepTimer(); - if (this.#sleepTimeout) { - clearTimeout(this.#sleepTimeout); - this.#sleepTimeout = undefined; + // Remove events from schedule that we're about to run + const runIndex = this.#persist.scheduledEvents.findIndex( + (x) => x.timestamp <= now, + ); + if (runIndex === -1) { + // This method is idempotent, so this will happen in scenarios like `start` and + // no events are pending. + this.#rLog.debug({ msg: "no events are due yet" }); + if (this.#persist.scheduledEvents.length > 0) { + const nextTs = this.#persist.scheduledEvents[0].timestamp; + this.actorContext.log.debug({ + msg: "alarm fired early, rescheduling for next event", + now, + nextTs, + delta: nextTs - now, + }); + await this.#queueSetAlarm(nextTs); + } + this.actorContext.log.debug({ msg: "no events to run", now }); + return; } + const scheduleEvents = this.#persist.scheduledEvents.splice( + 0, + runIndex + 1, + ); + this.actorContext.log.debug({ + msg: "running events", + count: scheduleEvents.length, + }); - // Abort any listeners waiting for shutdown - try { - this.#abortController.abort(); - } catch {} + // Set alarm for next event + if (this.#persist.scheduledEvents.length > 0) { + const nextTs = this.#persist.scheduledEvents[0].timestamp; + this.actorContext.log.info({ + msg: "setting next alarm", + nextTs, + remainingEvents: this.#persist.scheduledEvents.length, + }); + await this.#queueSetAlarm(nextTs); + } - // Call onStop lifecycle hook if defined - if (this.#config.onStop) { + // Iterate by event key in order to ensure we call the events in order + for (const event of scheduleEvents) { try { - this.#rLog.debug({ msg: "calling onStop" }); - const result = this.#config.onStop(this.actorContext); - if (result instanceof Promise) { - await deadline(result, this.#config.options.onStopTimeout); - } - this.#rLog.debug({ msg: "onStop completed" }); - } catch (error) { - if (error instanceof DeadlineError) { - this.#rLog.error({ msg: "onStop timed out" }); - } else { - this.#rLog.error({ - msg: "error in onStop", + this.actorContext.log.info({ + msg: "running action for event", + event: event.eventId, + timestamp: event.timestamp, + action: event.action, + }); + + // Look up function + const fn: unknown = this.#config.actions[event.action]; + + if (!fn) + throw new Error(`Missing action for alarm ${event.action}`); + if (typeof fn !== "function") + throw new Error( + `Alarm function lookup for ${event.action} returned ${typeof fn}`, + ); + + // Call function + try { + const args = event.args + ? cbor.decode(new Uint8Array(event.args)) + : []; + await fn.call(undefined, this.actorContext, ...args); + } catch (error) { + this.actorContext.log.error({ + msg: "error while running event", error: stringifyError(error), + event: event.eventId, + timestamp: event.timestamp, + action: event.action, }); } - } - } - - const promises: Promise[] = []; - - // Disconnect existing non-hibernatable connections - for (const connection of this.#connections.values()) { - if (!connection.isHibernatable) { - this.#rLog.debug({ - msg: "disconnecting non-hibernatable connection on actor stop", - connId: connection.id, + } catch (error) { + this.actorContext.log.error({ + msg: "internal error while running event", + error: stringifyError(error), + ...event, }); - promises.push(connection.disconnect()); } - - // TODO: Figure out how to abort HTTP requests on shutdown. This - // might already be handled by the engine runner tunnel shutdown. - } - - // Wait for any background tasks to finish, with timeout - await this.#waitBackgroundPromises( - this.#config.options.waitUntilTimeout, - ); - - // Clear timeouts - if (this.#pendingSaveTimeout) clearTimeout(this.#pendingSaveTimeout); - - // Write state - await this.saveState({ immediate: true, allowStoppingState: true }); - - // Await all `close` event listeners with 1.5 second timeout - const res = Promise.race([ - Promise.all(promises).then(() => false), - new Promise((res) => - globalThis.setTimeout(() => res(true), 1500), - ), - ]); - - if (await res) { - this.#rLog.warn({ - msg: "timed out waiting for connections to close, shutting down anyway", - }); } - - // Wait for queues to finish - if (this.#persistWriteQueue.runningDrainLoop) - await this.#persistWriteQueue.runningDrainLoop; - if (this.#alarmWriteQueue.runningDrainLoop) - await this.#alarmWriteQueue.runningDrainLoop; } - /** Abort signal that fires when the actor is stopping. */ - get abortSignal(): AbortSignal { - return this.#abortController.signal; + async #queueSetAlarm(timestamp: number): Promise { + await this.#alarmWriteQueue.enqueue(async () => { + await this.#actorDriver.setAlarm(this, timestamp); + }); } + // MARK: Background Promises /** Wait for background waitUntil promises with a timeout. */ async #waitBackgroundPromises(timeoutMs: number) { const pending = this.#backgroundPromises; @@ -2093,99 +2118,100 @@ export class ActorInstance { } } + /** + * Prevents the actor from sleeping until promise is complete. + * + * This allows the actor runtime to ensure that a promise completes while + * returning from an action request early. + * + * @param promise - The promise to run in the background. + */ + _waitUntil(promise: Promise) { + this.#assertReady(); + + // TODO: Should we force save the state? + // Add logging to promise and make it non-failable + const nonfailablePromise = promise + .then(() => { + this.#rLog.debug({ msg: "wait until promise complete" }); + }) + .catch((error) => { + this.#rLog.error({ + msg: "wait until promise failed", + error: stringifyError(error), + }); + }); + this.#backgroundPromises.push(nonfailablePromise); + } + // MARK: BARE Conversion Helpers #convertToBarePersisted( persist: PersistedActor, - ): bareSchema.PersistedActor { - // Merge connections with hibernatableWebSocket data into hibernatableConns - const hibernatableConns: bareSchema.PersistedHibernatableConn[] = []; - - for (const conn of persist.connections) { - if (conn.hibernatableRequestId) { - // Find matching hibernatable WebSocket - const ws = persist.hibernatableWebSocket.find((ws) => - arrayBuffersEqual( - ws.requestId, - conn.hibernatableRequestId!, - ), - ); - - if (ws) { - hibernatableConns.push({ - id: conn.connId, - parameters: bufferToArrayBuffer( - cbor.encode(conn.params || {}), - ), - state: bufferToArrayBuffer( - cbor.encode(conn.state || {}), - ), - hibernatableRequestId: conn.hibernatableRequestId, - lastSeenTimestamp: ws.lastSeenTimestamp, - msgIndex: ws.msgIndex, - }); - } - } - } + ): persistSchema.Actor { + // Convert hibernatable connections from the in-memory connections map + // Convert hibernatableConns from the persisted structure + const hibernatableConns: persistSchema.HibernatableConn[] = + persist.hibernatableConns.map((conn) => ({ + id: conn.id, + parameters: bufferToArrayBuffer( + cbor.encode(conn.parameters || {}), + ), + state: bufferToArrayBuffer(cbor.encode(conn.state || {})), + subscriptions: conn.subscriptions.map((sub) => ({ + eventName: sub.eventName, + })), + hibernatableRequestId: conn.hibernatableRequestId, + lastSeenTimestamp: BigInt(conn.lastSeenTimestamp), + msgIndex: BigInt(conn.msgIndex), + })); return { input: persist.input !== undefined ? bufferToArrayBuffer(cbor.encode(persist.input)) : null, - hasInitialized: persist.hasInitiated, + hasInitialized: persist.hasInitialized, state: bufferToArrayBuffer(cbor.encode(persist.state)), hibernatableConns, scheduledEvents: persist.scheduledEvents.map((event) => ({ eventId: event.eventId, timestamp: BigInt(event.timestamp), - action: event.kind.generic.actionName, - args: event.kind.generic.args ?? null, + action: event.action, + args: event.args ?? null, })), }; } #convertFromBarePersisted( - bareData: bareSchema.PersistedActor, + bareData: persistSchema.Actor, ): PersistedActor { - // Split hibernatableConns back into connections and hibernatableWebSocket - const connections: PersistedConn[] = []; - const hibernatableWebSocket: PersistedHibernatableWebSocket[] = []; - - for (const conn of bareData.hibernatableConns) { - connections.push({ - connId: conn.id, - params: cbor.decode(new Uint8Array(conn.parameters)), - state: cbor.decode(new Uint8Array(conn.state)), - subscriptions: [], - lastSeen: 0, // Will be set from lastSeenTimestamp + // Convert hibernatableConns from the BARE schema format + const hibernatableConns: PersistedHibernatableConn[] = + bareData.hibernatableConns.map((conn) => ({ + id: conn.id, + parameters: cbor.decode(new Uint8Array(conn.parameters)) as CP, + state: cbor.decode(new Uint8Array(conn.state)) as CS, + subscriptions: conn.subscriptions.map((sub) => ({ + eventName: sub.eventName, + })), hibernatableRequestId: conn.hibernatableRequestId, - }); - - hibernatableWebSocket.push({ - requestId: conn.hibernatableRequestId, - lastSeenTimestamp: conn.lastSeenTimestamp, - msgIndex: conn.msgIndex, - }); - } + lastSeenTimestamp: Number(conn.lastSeenTimestamp), + msgIndex: Number(conn.msgIndex), + })); return { input: bareData.input ? cbor.decode(new Uint8Array(bareData.input)) : undefined, - hasInitiated: bareData.hasInitialized, + hasInitialized: bareData.hasInitialized, state: cbor.decode(new Uint8Array(bareData.state)), - connections, + hibernatableConns, scheduledEvents: bareData.scheduledEvents.map((event) => ({ eventId: event.eventId, timestamp: Number(event.timestamp), - kind: { - generic: { - actionName: event.action, - args: event.args, - }, - }, + action: event.action, + args: event.args ?? undefined, })), - hibernatableWebSocket, }; } } diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/kv.ts b/rivetkit-typescript/packages/rivetkit/src/actor/kv.ts new file mode 100644 index 0000000000..1865bf4c54 --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/src/actor/kv.ts @@ -0,0 +1,14 @@ +export const KEYS = { + PERSIST_DATA: Uint8Array.from([1]), + CONN_PREFIX: Uint8Array.from([2]), // Prefix for connection keys +}; + +// Helper to create a connection key +export function makeConnKey(connId: string): Uint8Array { + const encoder = new TextEncoder(); + const connIdBytes = encoder.encode(connId); + const key = new Uint8Array(KEYS.CONN_PREFIX.length + connIdBytes.length); + key.set(KEYS.CONN_PREFIX, 0); + key.set(connIdBytes, KEYS.CONN_PREFIX.length); + return key; +} diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/persisted.ts b/rivetkit-typescript/packages/rivetkit/src/actor/persisted.ts index e236b47e67..fee27efda2 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/persisted.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/persisted.ts @@ -1,14 +1,41 @@ +/** + * Persisted data structures matching actor-persist/v3.bare schema + */ + +/** Scheduled event to be executed at a specific timestamp */ +export interface PersistedScheduleEvent { + eventId: string; + timestamp: number; + action: string; + args?: ArrayBuffer; +} + +/** Connection associated with hibernatable WebSocket that should persist across lifecycles */ +export interface PersistedHibernatableConn { + /** Connection ID generated by RivetKit */ + id: string; + parameters: CP; + state: CS; + subscriptions: PersistedSubscription[]; + /** Request ID of the hibernatable WebSocket */ + hibernatableRequestId: ArrayBuffer; + /** Last seen message from this WebSocket */ + lastSeenTimestamp: number; + /** Last seen message index for this WebSocket */ + msgIndex: number; +} + /** State object that gets automatically persisted to storage. */ export interface PersistedActor { + /** Input data passed to the actor on initialization */ input?: I; - hasInitiated: boolean; + hasInitialized: boolean; state: S; - connections: PersistedConn[]; + hibernatableConns: PersistedHibernatableConn[]; scheduledEvents: PersistedScheduleEvent[]; - hibernatableWebSocket: PersistedHibernatableWebSocket[]; } -/** Object representing connection that gets persisted to storage. */ +/** Object representing connection that gets persisted to storage separately via KV. */ export interface PersistedConn { connId: string; params: CP; @@ -18,31 +45,10 @@ export interface PersistedConn { /** Last time the socket was seen. This is set when disconnected so we can determine when we need to clean this up. */ lastSeen: number; - /** Request ID of the hibernatable WebSocket. See PersistedActor.hibernatableWebSocket */ + /** Request ID of the hibernatable WebSocket. See PersistedActor.hibernatableConns */ hibernatableRequestId?: ArrayBuffer; } export interface PersistedSubscription { eventName: string; } - -export interface GenericPersistedScheduleEvent { - actionName: string; - args: ArrayBuffer | null; -} - -export type PersistedScheduleEventKind = { - generic: GenericPersistedScheduleEvent; -}; - -export interface PersistedScheduleEvent { - eventId: string; - timestamp: number; - kind: PersistedScheduleEventKind; -} - -export interface PersistedHibernatableWebSocket { - requestId: ArrayBuffer; - lastSeenTimestamp: bigint; - msgIndex: bigint; -} diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts b/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts index 163592cc8e..b9d1bd8d4f 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts @@ -158,9 +158,12 @@ export async function handleWebSocketConnect( // Check if this is a hibernatable websocket const isHibernatable = !!requestIdBuf && - actor[PERSIST_SYMBOL].hibernatableWebSocket.findIndex( - (ws) => - arrayBuffersEqual(ws.requestId, requestIdBuf), + actor[PERSIST_SYMBOL].hibernatableConns.findIndex( + (conn) => + arrayBuffersEqual( + conn.hibernatableRequestId, + requestIdBuf, + ), ) !== -1; conn = await actor.createConn( @@ -391,8 +394,11 @@ export async function handleRawWebSocketHandler( // Extract rivetRequestId provided by engine runner const rivetRequestId = evt?.rivetRequestId; const isHibernatable = - actor[PERSIST_SYMBOL].hibernatableWebSocket.findIndex((ws) => - arrayBuffersEqual(ws.requestId, rivetRequestId), + actor[PERSIST_SYMBOL].hibernatableConns.findIndex((conn) => + arrayBuffersEqual( + conn.hibernatableRequestId, + rivetRequestId, + ), ) !== -1; // Wrap the Hono WebSocket in our adapter diff --git a/rivetkit-typescript/packages/rivetkit/src/driver-helpers/utils.ts b/rivetkit-typescript/packages/rivetkit/src/driver-helpers/utils.ts index 9c875043e8..0936f816b2 100644 --- a/rivetkit-typescript/packages/rivetkit/src/driver-helpers/utils.ts +++ b/rivetkit-typescript/packages/rivetkit/src/driver-helpers/utils.ts @@ -1,12 +1,12 @@ import * as cbor from "cbor-x"; -import type * as schema from "@/schemas/actor-persist/mod"; -import { PERSISTED_ACTOR_VERSIONED } from "@/schemas/actor-persist/versioned"; +import type * as persistSchema from "@/schemas/actor-persist/mod"; +import { ACTOR_VERSIONED } from "@/schemas/actor-persist/versioned"; import { bufferToArrayBuffer } from "@/utils"; export function serializeEmptyPersistData( input: unknown | undefined, ): Uint8Array { - const persistData: schema.PersistedActor = { + const persistData: persistSchema.Actor = { input: input !== undefined ? bufferToArrayBuffer(cbor.encode(input)) @@ -16,5 +16,5 @@ export function serializeEmptyPersistData( hibernatableConns: [], scheduledEvents: [], }; - return PERSISTED_ACTOR_VERSIONED.serializeWithEmbeddedVersion(persistData); + return ACTOR_VERSIONED.serializeWithEmbeddedVersion(persistData); } diff --git a/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts b/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts index 17152b4770..9bfdc507fb 100644 --- a/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts +++ b/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts @@ -12,6 +12,7 @@ import invariant from "invariant"; import { lookupInRegistry } from "@/actor/definition"; import { PERSIST_SYMBOL } from "@/actor/instance"; import { deserializeActorKey } from "@/actor/keys"; +import { KEYS } from "@/actor/kv"; import { EncodingSchema } from "@/actor/protocol/serde"; import { type ActorRouter, createActorRouter } from "@/actor/router"; import { @@ -49,7 +50,6 @@ import { setLongTimeout, stringifyError, } from "@/utils"; -import { KEYS } from "./kv"; import { logger } from "./log"; const RUNNER_SSE_PING_INTERVAL = 1000; @@ -57,7 +57,6 @@ const RUNNER_SSE_PING_INTERVAL = 1000; interface ActorHandler { actor?: AnyActorInstance; actorStartPromise?: ReturnType>; - persistedData?: Uint8Array; } export type DriverContext = {}; @@ -171,14 +170,14 @@ export class EngineActorDriver implements ActorDriver { // Check for existing WS const hibernatableArray = - handler.actor[PERSIST_SYMBOL].hibernatableWebSocket; + handler.actor[PERSIST_SYMBOL].hibernatableConns; logger().debug({ msg: "checking hibernatable websockets", requestId: idToStr(requestId), existingHibernatableWebSockets: hibernatableArray.length, }); - const existingWs = hibernatableArray.find((ws) => - arrayBuffersEqual(ws.requestId, requestId), + const existingWs = hibernatableArray.find((conn) => + arrayBuffersEqual(conn.hibernatableRequestId, requestId), ); // Determine configuration for new WS @@ -269,17 +268,16 @@ export class EngineActorDriver implements ActorDriver { msg: "updated existing hibernatable websocket timestamp", requestId: idToStr(requestId), }); - existingWs.lastSeenTimestamp = BigInt(Date.now()); - } else { + existingWs.lastSeenTimestamp = Date.now(); + } else if (path === PATH_CONNECT) { + // For new hibernatable connections, we'll create a placeholder entry + // The actual connection data will be populated when the connection is created logger().debug({ - msg: "created new hibernatable websocket entry", + msg: "will create hibernatable conn when connection is created", requestId: idToStr(requestId), }); - handler.actor[PERSIST_SYMBOL].hibernatableWebSocket.push({ - requestId, - lastSeenTimestamp: BigInt(Date.now()), - msgIndex: -1n, - }); + // Note: The actual hibernatable connection is created in instance.ts + // when createConn is called with a hibernatable requestId } return hibernationConfig; @@ -339,29 +337,6 @@ export class EngineActorDriver implements ActorDriver { return {}; } - async readPersistedData(actorId: string): Promise { - const handler = this.#actors.get(actorId); - if (!handler) throw new Error(`Actor ${actorId} not loaded`); - - // This was loaded during actor startup - return handler.persistedData; - } - - async writePersistedData(actorId: string, data: Uint8Array): Promise { - const handler = this.#actors.get(actorId); - if (!handler) throw new Error(`Actor ${actorId} not loaded`); - - handler.persistedData = data; - - logger().debug({ - msg: "writing persisted data for actor", - actorId, - dataSize: data.byteLength, - }); - - await this.#runner.kvPut(actorId, [[KEYS.PERSIST_DATA, data]]); - } - async setAlarm(actor: AnyActorInstance, timestamp: number): Promise { // Clear prev timeout if (this.#alarmTimeout) { @@ -392,6 +367,56 @@ export class EngineActorDriver implements ActorDriver { return undefined; } + // Batch KV operations + async kvBatchPut( + actorId: string, + entries: [Uint8Array, Uint8Array][], + ): Promise { + logger().debug({ + msg: "batch writing KV entries", + actorId, + entryCount: entries.length, + }); + + await this.#runner.kvPut(actorId, entries); + } + + async kvBatchGet( + actorId: string, + keys: Uint8Array[], + ): Promise<(Uint8Array | null)[]> { + logger().debug({ + msg: "batch reading KV entries", + actorId, + keyCount: keys.length, + }); + + return await this.#runner.kvGet(actorId, keys); + } + + async kvBatchDelete(actorId: string, keys: Uint8Array[]): Promise { + logger().debug({ + msg: "batch deleting KV entries", + actorId, + keyCount: keys.length, + }); + + await this.#runner.kvDelete(actorId, keys); + } + + async kvListPrefix( + actorId: string, + prefix: Uint8Array, + ): Promise<[Uint8Array, Uint8Array][]> { + logger().debug({ + msg: "listing KV entries with prefix", + actorId, + prefixLength: prefix.length, + }); + + return await this.#runner.kvListPrefix(actorId, prefix); + } + // Runner lifecycle callbacks async #runnerOnActorStart( actorId: string, @@ -420,26 +445,8 @@ export class EngineActorDriver implements ActorDriver { // create the same handler simultaneously. handler = { actorStartPromise: promiseWithResolvers(), - persistedData: undefined, }; this.#actors.set(actorId, handler); - - // Load persisted data from storage - const [persistedValue] = await this.#runner.kvGet(actorId, [ - KEYS.PERSIST_DATA, - ]); - - handler.persistedData = - persistedValue !== null - ? persistedValue - : serializeEmptyPersistData(input); - - logger().debug({ - msg: "loaded persisted data for actor", - actorId, - dataSize: handler.persistedData?.byteLength, - wasInStorage: persistedValue !== null, - }); } const name = actorConfig.name as string; diff --git a/rivetkit-typescript/packages/rivetkit/src/drivers/engine/kv.ts b/rivetkit-typescript/packages/rivetkit/src/drivers/engine/kv.ts deleted file mode 100644 index 9ea919fca7..0000000000 --- a/rivetkit-typescript/packages/rivetkit/src/drivers/engine/kv.ts +++ /dev/null @@ -1,3 +0,0 @@ -export const KEYS = { - PERSIST_DATA: Uint8Array.from([1]), -}; diff --git a/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/actor.ts b/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/actor.ts index 0ff7b47705..37989f220b 100644 --- a/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/actor.ts +++ b/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/actor.ts @@ -5,7 +5,6 @@ import type { ManagerDriver, } from "@/driver-helpers/mod"; import type { RegistryConfig, RunConfig } from "@/mod"; -import { bufferToArrayBuffer } from "@/utils"; import type { FileSystemGlobalState } from "./global-state"; export type ActorDriverContext = Record; @@ -55,20 +54,29 @@ export class FileSystemActorDriver implements ActorDriver { return {}; } - async readPersistedData(actorId: string): Promise { - return new Uint8Array( - (await this.#state.loadActorStateOrError(actorId)).persistedData, - ); + async kvBatchPut( + actorId: string, + entries: [Uint8Array, Uint8Array][], + ): Promise { + await this.#state.kvBatchPut(actorId, entries); + } + + async kvBatchGet( + actorId: string, + keys: Uint8Array[], + ): Promise<(Uint8Array | null)[]> { + return await this.#state.kvBatchGet(actorId, keys); } - async writePersistedData(actorId: string, data: Uint8Array): Promise { - const state = await this.#state.loadActorStateOrError(actorId); + async kvBatchDelete(actorId: string, keys: Uint8Array[]): Promise { + await this.#state.kvBatchDelete(actorId, keys); + } - // Save state to disk - await this.#state.writeActor(actorId, { - ...state, - persistedData: bufferToArrayBuffer(data), - }); + async kvListPrefix( + actorId: string, + prefix: Uint8Array, + ): Promise<[Uint8Array, Uint8Array][]> { + return await this.#state.kvListPrefix(actorId, prefix); } async setAlarm(actor: AnyActorInstance, timestamp: number): Promise { diff --git a/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/global-state.ts b/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/global-state.ts index f1ac495853..016d3a489a 100644 --- a/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/global-state.ts +++ b/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/global-state.ts @@ -21,6 +21,7 @@ import { ACTOR_STATE_VERSIONED, } from "@/schemas/file-system-driver/versioned"; import { + arrayBuffersEqual, bufferToArrayBuffer, type LongTimeoutHandle, promiseWithResolvers, @@ -213,14 +214,22 @@ export class FileSystemGlobalState { } const entry = this.#upsertEntry(actorId); + + // Initialize kvStorage with the initial persist data + const kvStorage: schema.ActorKvEntry[] = []; + const persistData = serializeEmptyPersistData(input); + // Store under key [1] + kvStorage.push({ + key: bufferToArrayBuffer(new Uint8Array([1])), + value: bufferToArrayBuffer(persistData), + }); + entry.state = { actorId, name, key, createdAt: BigInt(Date.now()), - persistedData: bufferToArrayBuffer( - serializeEmptyPersistData(input), - ), + kvStorage, }; await this.writeActor(actorId, entry.state); return entry; @@ -292,14 +301,21 @@ export class FileSystemGlobalState { // If no state for this actor, then create & write state if (!entry.state) { + // Initialize kvStorage with the initial persist data + const kvStorage: schema.ActorKvEntry[] = []; + const persistData = serializeEmptyPersistData(input); + // Store under key [1] + kvStorage.push({ + key: bufferToArrayBuffer(new Uint8Array([1])), + value: bufferToArrayBuffer(persistData), + }); + entry.state = { actorId, name, key: key as readonly string[], createdAt: BigInt(Date.now()), - persistedData: bufferToArrayBuffer( - serializeEmptyPersistData(input), - ), + kvStorage, }; await this.writeActor(actorId, entry.state); } @@ -403,7 +419,7 @@ export class FileSystemGlobalState { name: state.name, key: state.key, createdAt: state.createdAt, - persistedData: state.persistedData, + kvStorage: state.kvStorage, }; // Perform atomic write @@ -716,4 +732,144 @@ export class FileSystemGlobalState { }); } } + + /** + * Batch put KV entries for an actor. + */ + async kvBatchPut( + actorId: string, + entries: [Uint8Array, Uint8Array][], + ): Promise { + const entry = await this.loadActor(actorId); + if (!entry.state) { + throw new Error(`Actor ${actorId} state not loaded`); + } + + // Create a mutable copy of kvStorage + const newKvStorage = [...entry.state.kvStorage]; + + // Update kvStorage with new entries + for (const [key, value] of entries) { + // Find existing entry with the same key + const existingIndex = newKvStorage.findIndex((e) => + arrayBuffersEqual(e.key, bufferToArrayBuffer(key)), + ); + + if (existingIndex >= 0) { + // Replace existing entry with new one + newKvStorage[existingIndex] = { + key: bufferToArrayBuffer(key), + value: bufferToArrayBuffer(value), + }; + } else { + // Add new entry + newKvStorage.push({ + key: bufferToArrayBuffer(key), + value: bufferToArrayBuffer(value), + }); + } + } + + // Update state with new kvStorage + entry.state = { + ...entry.state, + kvStorage: newKvStorage, + }; + + // Save state to disk + await this.writeActor(actorId, entry.state); + } + + /** + * Batch get KV entries for an actor. + */ + async kvBatchGet( + actorId: string, + keys: Uint8Array[], + ): Promise<(Uint8Array | null)[]> { + const entry = await this.loadActor(actorId); + if (!entry.state) { + throw new Error(`Actor ${actorId} state not loaded`); + } + + const results: (Uint8Array | null)[] = []; + for (const key of keys) { + // Find entry with the same key + const foundEntry = entry.state.kvStorage.find((e) => + arrayBuffersEqual(e.key, bufferToArrayBuffer(key)), + ); + + if (foundEntry) { + results.push(new Uint8Array(foundEntry.value)); + } else { + results.push(null); + } + } + return results; + } + + /** + * Batch delete KV entries for an actor. + */ + async kvBatchDelete(actorId: string, keys: Uint8Array[]): Promise { + const entry = await this.loadActor(actorId); + if (!entry.state) { + throw new Error(`Actor ${actorId} state not loaded`); + } + + // Create a mutable copy of kvStorage + const newKvStorage = [...entry.state.kvStorage]; + + // Delete entries from kvStorage + for (const key of keys) { + const indexToDelete = newKvStorage.findIndex((e) => + arrayBuffersEqual(e.key, bufferToArrayBuffer(key)), + ); + + if (indexToDelete >= 0) { + newKvStorage.splice(indexToDelete, 1); + } + } + + // Update state with new kvStorage + entry.state = { + ...entry.state, + kvStorage: newKvStorage, + }; + + // Save state to disk + await this.writeActor(actorId, entry.state); + } + + /** + * List KV entries with a given prefix for an actor. + */ + async kvListPrefix( + actorId: string, + prefix: Uint8Array, + ): Promise<[Uint8Array, Uint8Array][]> { + const entry = await this.loadActor(actorId); + if (!entry.state) { + throw new Error(`Actor ${actorId} state not loaded`); + } + + const results: [Uint8Array, Uint8Array][] = []; + for (const kvEntry of entry.state.kvStorage) { + const keyBytes = new Uint8Array(kvEntry.key); + // Check if key starts with prefix + if (keyBytes.length >= prefix.length) { + let hasPrefix = true; + for (let i = 0; i < prefix.length; i++) { + if (keyBytes[i] !== prefix[i]) { + hasPrefix = false; + break; + } + } + if (hasPrefix) { + results.push([keyBytes, new Uint8Array(kvEntry.value)]); + } + } + } + return results; + } } diff --git a/rivetkit-typescript/packages/rivetkit/src/schemas/actor-persist/versioned.ts b/rivetkit-typescript/packages/rivetkit/src/schemas/actor-persist/versioned.ts index 84ba566473..068c7fc5b6 100644 --- a/rivetkit-typescript/packages/rivetkit/src/schemas/actor-persist/versioned.ts +++ b/rivetkit-typescript/packages/rivetkit/src/schemas/actor-persist/versioned.ts @@ -8,10 +8,6 @@ import * as v3 from "../../../dist/schemas/actor-persist/v3"; export const CURRENT_VERSION = 3; -export type CurrentPersistedActor = v3.PersistedActor; -export type CurrentPersistedHibernatableConn = v3.PersistedHibernatableConn; -export type CurrentPersistedScheduleEvent = v3.PersistedScheduleEvent; - const migrations = new Map>([ [ 1, @@ -26,9 +22,9 @@ const migrations = new Map>([ ], [ 2, - (v2Data: v2.PersistedActor): v3.PersistedActor => { + (v2Data: v2.PersistedActor): v3.Actor => { // Merge connections and hibernatableWebSocket into hibernatableConns - const hibernatableConns: v3.PersistedHibernatableConn[] = []; + const hibernatableConns: v3.HibernatableConn[] = []; // Convert connections with hibernatable request IDs to hibernatable conns for (const conn of v2Data.connections) { @@ -45,6 +41,9 @@ const migrations = new Map>([ id: conn.id, parameters: conn.parameters, state: conn.state, + subscriptions: conn.subscriptions.map((sub) => ({ + eventName: sub.eventName, + })), hibernatableRequestId: conn.hibernatableRequestId, lastSeenTimestamp: ws.lastSeenTimestamp, msgIndex: ws.msgIndex, @@ -54,7 +53,7 @@ const migrations = new Map>([ } // Transform scheduled events from nested structure to flat structure - const scheduledEvents: v3.PersistedScheduleEvent[] = + const scheduledEvents: v3.ScheduleEvent[] = v2Data.scheduledEvents.map((event) => { // Extract action and args from the kind wrapper if (event.kind.tag === "GenericPersistedScheduleEvent") { @@ -82,10 +81,9 @@ const migrations = new Map>([ ], ]); -export const PERSISTED_ACTOR_VERSIONED = - createVersionedDataHandler({ - currentVersion: CURRENT_VERSION, - migrations, - serializeVersion: (data) => v3.encodePersistedActor(data), - deserializeVersion: (bytes) => v3.decodePersistedActor(bytes), - }); +export const ACTOR_VERSIONED = createVersionedDataHandler({ + currentVersion: CURRENT_VERSION, + migrations, + serializeVersion: (data) => v3.encodeActor(data), + deserializeVersion: (bytes) => v3.decodeActor(bytes), +}); diff --git a/rivetkit-typescript/packages/rivetkit/src/schemas/client-protocol/versioned.ts b/rivetkit-typescript/packages/rivetkit/src/schemas/client-protocol/versioned.ts index 51d5da5347..770238cc8a 100644 --- a/rivetkit-typescript/packages/rivetkit/src/schemas/client-protocol/versioned.ts +++ b/rivetkit-typescript/packages/rivetkit/src/schemas/client-protocol/versioned.ts @@ -7,14 +7,6 @@ import * as v2 from "../../../dist/schemas/client-protocol/v2"; export const CURRENT_VERSION = 2; -export type CurrentToServer = v2.ToServer; -export type CurrentToClient = v2.ToClient; -export type CurrentHttpActionRequest = v2.HttpActionRequest; -export type CurrentHttpActionResponse = v2.HttpActionResponse; -export type CurrentHttpResponseError = v2.HttpResponseError; -export type CurrentHttpResolveRequest = v2.HttpResolveRequest; -export type CurrentHttpResolveResponse = v2.HttpResolveResponse; - const migrations = new Map>(); // Migration from v1 to v2: Remove connectionToken from Init message diff --git a/rivetkit-typescript/packages/rivetkit/src/schemas/file-system-driver/mod.ts b/rivetkit-typescript/packages/rivetkit/src/schemas/file-system-driver/mod.ts index c7431b7f39..1ecf054619 100644 --- a/rivetkit-typescript/packages/rivetkit/src/schemas/file-system-driver/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/schemas/file-system-driver/mod.ts @@ -1 +1 @@ -export * from "../../../dist/schemas/file-system-driver/v1"; +export * from "../../../dist/schemas/file-system-driver/v2"; diff --git a/rivetkit-typescript/packages/rivetkit/src/schemas/file-system-driver/versioned.ts b/rivetkit-typescript/packages/rivetkit/src/schemas/file-system-driver/versioned.ts index c046584581..d013644603 100644 --- a/rivetkit-typescript/packages/rivetkit/src/schemas/file-system-driver/versioned.ts +++ b/rivetkit-typescript/packages/rivetkit/src/schemas/file-system-driver/versioned.ts @@ -2,27 +2,50 @@ import { createVersionedDataHandler, type MigrationFn, } from "@/common/versioned-data"; -import * as v1 from "../../../dist/schemas/file-system-driver/v1"; +import { bufferToArrayBuffer } from "@/utils"; +import type * as v1 from "../../../dist/schemas/file-system-driver/v1"; +import * as v2 from "../../../dist/schemas/file-system-driver/v2"; -export const CURRENT_VERSION = 1; +export const CURRENT_VERSION = 2; -export type CurrentActorState = v1.ActorState; -export type CurrentActorAlarm = v1.ActorAlarm; +const migrations = new Map>([ + [ + 2, + (v1State: v1.ActorState): v2.ActorState => { + // Create a new kvStorage list with the legacy persist data + const kvStorage: v2.ActorKvEntry[] = []; -const migrations = new Map>(); + // Store the legacy persist data under key [1] + if (v1State.persistedData) { + // Key [1] as Uint8Array + const key = new Uint8Array([1]); + kvStorage.push({ + key: bufferToArrayBuffer(key), + value: v1State.persistedData, + }); + } -export const ACTOR_STATE_VERSIONED = - createVersionedDataHandler({ - currentVersion: CURRENT_VERSION, - migrations, - serializeVersion: (data) => v1.encodeActorState(data), - deserializeVersion: (bytes) => v1.decodeActorState(bytes), - }); + return { + actorId: v1State.actorId, + name: v1State.name, + key: v1State.key, + kvStorage, + createdAt: v1State.createdAt, + }; + }, + ], +]); -export const ACTOR_ALARM_VERSIONED = - createVersionedDataHandler({ - currentVersion: CURRENT_VERSION, - migrations, - serializeVersion: (data) => v1.encodeActorAlarm(data), - deserializeVersion: (bytes) => v1.decodeActorAlarm(bytes), - }); +export const ACTOR_STATE_VERSIONED = createVersionedDataHandler({ + currentVersion: CURRENT_VERSION, + migrations, + serializeVersion: (data) => v2.encodeActorState(data), + deserializeVersion: (bytes) => v2.decodeActorState(bytes), +}); + +export const ACTOR_ALARM_VERSIONED = createVersionedDataHandler({ + currentVersion: CURRENT_VERSION, + migrations, + serializeVersion: (data) => v2.encodeActorAlarm(data), + deserializeVersion: (bytes) => v2.decodeActorAlarm(bytes), +});