Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 73 additions & 26 deletions rivetkit-typescript/packages/cloudflare-workers/src/actor-driver.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import invariant from "invariant";
import type {
ActorKey,
ActorRouter,
AnyActorInstance as CoreAnyActorInstance,
RegistryConfig,
RunConfig,
Expand Down Expand Up @@ -31,8 +33,8 @@ export class CloudflareDurableObjectGlobalState {
// Map of actor ID -> DO state
#dos: Map<string, DurableObjectGlobalState> = new Map();

// Map of DO ID -> ActorHandler
#actors: Map<string, ActorHandler> = new Map();
// WeakMap of DO state -> ActorGlobalState for proper GC
#actors: WeakMap<DurableObjectState, ActorGlobalState> = new WeakMap();

getDOState(doId: string): DurableObjectGlobalState {
const state = this.#dos.get(doId);
Expand All @@ -47,20 +49,40 @@ export class CloudflareDurableObjectGlobalState {
this.#dos.set(doId, state);
}

get actors() {
return this.#actors;
getActorState(ctx: DurableObjectState): ActorGlobalState | undefined {
return this.#actors.get(ctx);
}

setActorState(ctx: DurableObjectState, actorState: ActorGlobalState): void {
this.#actors.set(ctx, actorState);
}
}

export interface DriverContext {
state: DurableObjectState;
}

// Actor handler to track running instances
class ActorHandler {
actor?: AnyActorInstance;
actorPromise?: ReturnType<typeof promiseWithResolvers<void>> =
promiseWithResolvers();
interface InitializedData {
name: string;
key: ActorKey;
generation: number;
}

interface LoadedActor {
actorRouter: ActorRouter;
actorDriver: ActorDriver;
generation: number;
}

// Actor global state to track running instances
export class ActorGlobalState {
// Initialization state
initialized?: InitializedData;

// Loaded actor state
actor?: LoadedActor;
actorInstance?: AnyActorInstance;
actorPromise?: ReturnType<typeof promiseWithResolvers<void>>;

/**
* Indicates if `startDestroy` has been called.
Expand All @@ -70,6 +92,14 @@ class ActorHandler {
* See the corresponding `destroyed` property in SQLite metadata.
*/
destroying: boolean = false;

reset() {
this.initialized = undefined;
this.actor = undefined;
this.actorInstance = undefined;
this.actorPromise = undefined;
this.destroying = false;
}
}

export class CloudflareActorsActorDriver implements ActorDriver {
Expand Down Expand Up @@ -103,20 +133,24 @@ export class CloudflareActorsActorDriver implements ActorDriver {
// Parse actor ID to get DO ID and generation
const [doId, expectedGeneration] = parseActorId(actorId);

// Get the DO state
const doState = this.#globalState.getDOState(doId);

// Check if actor is already loaded
let handler = this.#globalState.actors.get(doId);
if (handler) {
if (handler.actorPromise) await handler.actorPromise.promise;
if (!handler.actor) throw new Error("Actor should be loaded");
return handler.actor;
let handler = this.#globalState.getActorState(doState.ctx);
if (handler && handler.actorInstance) {
// Actor is already loaded, return it
return handler.actorInstance;
}

// Create new actor handler
handler = new ActorHandler();
this.#globalState.actors.set(doId, handler);
// Create new actor handler if it doesn't exist
if (!handler) {
handler = new ActorGlobalState();
handler.actorPromise = promiseWithResolvers();
this.#globalState.setActorState(doState.ctx, handler);
}

// Get the actor metadata from Durable Object storage
const doState = this.#globalState.getDOState(doId);
const sql = doState.ctx.storage.sql;

// Load actor metadata from SQL table
Expand Down Expand Up @@ -150,10 +184,10 @@ export class CloudflareActorsActorDriver implements ActorDriver {

// Create actor instance
const definition = lookupInRegistry(this.#registryConfig, name);
handler.actor = definition.instantiate();
handler.actorInstance = definition.instantiate();

// Start actor
await handler.actor.start(
await handler.actorInstance.start(
this,
this.#inlineClient,
actorId,
Expand All @@ -166,7 +200,7 @@ export class CloudflareActorsActorDriver implements ActorDriver {
handler.actorPromise?.resolve();
handler.actorPromise = undefined;

return handler.actor;
return handler.actorInstance;
}

getContext(actorId: string): DriverContext {
Expand Down Expand Up @@ -231,10 +265,12 @@ export class CloudflareActorsActorDriver implements ActorDriver {
// Parse actor ID to get DO ID and generation
const [doId, generation] = parseActorId(actorId);

const handler = this.#globalState.actors.get(doId);
// Get the DO state
const doState = this.#globalState.getDOState(doId);
const handler = this.#globalState.getActorState(doState.ctx);

// Actor not loaded, nothing to destroy
if (!handler || !handler.actor) {
if (!handler || !handler.actorInstance) {
return;
}

Expand All @@ -244,8 +280,17 @@ export class CloudflareActorsActorDriver implements ActorDriver {
}
handler.destroying = true;

// Spawn onStop
this.#callOnStopAsync(actorId, doId, handler.actorInstance);
}

async #callOnStopAsync(
actorId: string,
doId: string,
actor: CoreAnyActorInstance,
) {
// Stop
handler.actor.onStop("destroy");
await actor.onStop("destroy");

// Remove state
const doState = this.#globalState.getDOState(doId);
Expand All @@ -254,15 +299,17 @@ export class CloudflareActorsActorDriver implements ActorDriver {
sql.exec("DELETE FROM _rivetkit_kv_storage");

// Clear any scheduled alarms
doState.ctx.storage.deleteAlarm();
await doState.ctx.storage.deleteAlarm();

// Delete from ACTOR_KV in the background - use full actorId including generation
const env = getCloudflareAmbientEnv();
doState.ctx.waitUntil(
env.ACTOR_KV.delete(GLOBAL_KV_KEYS.actorMetadata(actorId)),
);

this.#globalState.actors.delete(doId);
// Reset global state using the DO context
const actorHandle = this.#globalState.getActorState(doState.ctx);
actorHandle?.reset();
}
}

Expand Down
Loading
Loading