Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 93 additions & 8 deletions rivetkit-typescript/packages/cloudflare-workers/src/actor-driver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -139,21 +139,106 @@ export class CloudflareActorsActorDriver implements ActorDriver {
return { state: state.ctx };
}

async readPersistedData(actorId: string): Promise<Uint8Array | undefined> {
return await this.#getDOCtx(actorId).storage.get(KEYS.PERSIST_DATA);
}

async writePersistedData(actorId: string, data: Uint8Array): Promise<void> {
await this.#getDOCtx(actorId).storage.put(KEYS.PERSIST_DATA, data);
}

async setAlarm(actor: AnyActorInstance, timestamp: number): Promise<void> {
await this.#getDOCtx(actor.id).storage.setAlarm(timestamp);
}

async getDatabase(actorId: string): Promise<unknown | undefined> {
return this.#getDOCtx(actorId).storage.sql;
}

// Batch KV operations - convert between Uint8Array and Cloudflare's string-based API
async kvBatchPut(
actorId: string,
entries: [Uint8Array, Uint8Array][],
): Promise<void> {
const storage = this.#getDOCtx(actorId).storage;
const encoder = new TextDecoder();

// Convert Uint8Array entries to object for Cloudflare batch put
const storageObj: Record<string, Uint8Array> = {};
for (const [key, value] of entries) {
// Convert key from Uint8Array to string
const keyStr = this.#uint8ArrayToKey(key);
storageObj[keyStr] = value;
}

await storage.put(storageObj);
}

async kvBatchGet(
actorId: string,
keys: Uint8Array[],
): Promise<(Uint8Array | null)[]> {
const storage = this.#getDOCtx(actorId).storage;

// Convert keys to strings
const keyStrs = keys.map((k) => this.#uint8ArrayToKey(k));

// Get values from storage
const results = await storage.get<Uint8Array>(keyStrs);

// Convert Map results to array in same order as input keys
return keyStrs.map((k) => results.get(k) ?? null);
}

async kvBatchDelete(actorId: string, keys: Uint8Array[]): Promise<void> {
const storage = this.#getDOCtx(actorId).storage;

// Convert keys to strings
const keyStrs = keys.map((k) => this.#uint8ArrayToKey(k));

await storage.delete(keyStrs);
}

async kvListPrefix(
actorId: string,
prefix: Uint8Array,
): Promise<[Uint8Array, Uint8Array][]> {
const storage = this.#getDOCtx(actorId).storage;

// Convert prefix to string
const prefixStr = this.#uint8ArrayToKey(prefix);

// List with prefix
const results = await storage.list<Uint8Array>({ prefix: prefixStr });

// Convert Map to array of [key, value] tuples
const entries: [Uint8Array, Uint8Array][] = [];
for (const [key, value] of results) {
entries.push([this.#keyToUint8Array(key), value]);
}

return entries;
}

// Helper to convert Uint8Array key to string for Cloudflare storage
#uint8ArrayToKey(key: Uint8Array): string {
// Check if this is a connection key (starts with [2])
if (key.length > 0 && key[0] === 2) {
// Connection key - extract connId
const connId = new TextDecoder().decode(key.slice(1));
return `${KEYS.CONN_PREFIX}${connId}`;
}
// Otherwise, treat as persist data key [1]
return KEYS.PERSIST_DATA;
}

// Helper to convert string key back to Uint8Array
#keyToUint8Array(key: string): Uint8Array {
if (key.startsWith(KEYS.CONN_PREFIX)) {
// Connection key
const connId = key.slice(KEYS.CONN_PREFIX.length);
const encoder = new TextEncoder();
const connIdBytes = encoder.encode(connId);
const result = new Uint8Array(1 + connIdBytes.length);
result[0] = 2; // Connection prefix
result.set(connIdBytes, 1);
return result;
}
// Persist data key
return Uint8Array.from([1]);
}
}

export function createCloudflareActorsActorDriverBuilder(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,14 @@ export const KEYS = {
NAME: "rivetkit:name",
KEY: "rivetkit:key",
PERSIST_DATA: "rivetkit:data",
CONN_PREFIX: "rivetkit:conn:",
};

// Helper to create a connection key for Cloudflare
export function makeCloudflareConnKey(connId: string): string {
return `${KEYS.CONN_PREFIX}${connId}`;
}

export interface ActorHandlerInterface extends DurableObject {
initialize(req: ActorInitRequest): Promise<void>;
}
Expand Down
2 changes: 1 addition & 1 deletion rivetkit-typescript/packages/rivetkit/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@
],
"scripts": {
"build": "tsup src/mod.ts src/client/mod.ts src/common/log.ts src/common/websocket.ts src/actor/errors.ts src/topologies/coordinate/mod.ts src/topologies/partition/mod.ts src/utils.ts src/driver-helpers/mod.ts src/driver-test-suite/mod.ts src/test/mod.ts src/inspector/mod.ts",
"build:schema": "./scripts/compile-bare.ts compile schemas/client-protocol/v1.bare -o dist/schemas/client-protocol/v1.ts && ./scripts/compile-bare.ts compile schemas/client-protocol/v2.bare -o dist/schemas/client-protocol/v2.ts && ./scripts/compile-bare.ts compile schemas/file-system-driver/v1.bare -o dist/schemas/file-system-driver/v1.ts && ./scripts/compile-bare.ts compile schemas/actor-persist/v1.bare -o dist/schemas/actor-persist/v1.ts && ./scripts/compile-bare.ts compile schemas/actor-persist/v2.bare -o dist/schemas/actor-persist/v2.ts && ./scripts/compile-bare.ts compile schemas/actor-persist/v3.bare -o dist/schemas/actor-persist/v3.ts",
"build:schema": "./scripts/compile-bare.ts compile schemas/client-protocol/v1.bare -o dist/schemas/client-protocol/v1.ts && ./scripts/compile-bare.ts compile schemas/client-protocol/v2.bare -o dist/schemas/client-protocol/v2.ts && ./scripts/compile-bare.ts compile schemas/file-system-driver/v1.bare -o dist/schemas/file-system-driver/v1.ts && ./scripts/compile-bare.ts compile schemas/file-system-driver/v2.bare -o dist/schemas/file-system-driver/v2.ts && ./scripts/compile-bare.ts compile schemas/actor-persist/v1.bare -o dist/schemas/actor-persist/v1.ts && ./scripts/compile-bare.ts compile schemas/actor-persist/v2.bare -o dist/schemas/actor-persist/v2.ts && ./scripts/compile-bare.ts compile schemas/actor-persist/v3.bare -o dist/schemas/actor-persist/v3.ts",
"check-types": "tsc --noEmit",
"test": "vitest run",
"test:watch": "vitest",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
# MARK: Connection
type Subscription struct {
eventName: str
}

# Connection associated with hibernatable WebSocket that should persist across lifecycles.
type PersistedHibernatableConn struct {
type HibernatableConn struct {
# Connection ID generated by RivetKit
id: str
parameters: data
state: data
subscriptions: list<Subscription>

# Request ID of the hibernatable WebSocket
hibernatableRequestId: data
Expand All @@ -15,19 +20,19 @@ type PersistedHibernatableConn struct {
}

# MARK: Schedule Event
type PersistedScheduleEvent struct {
type ScheduleEvent struct {
eventId: str
timestamp: i64
action: str
args: optional<data>
}

# MARK: Actor
type PersistedActor struct {
type Actor struct {
# Input data passed to the actor on initialization
input: optional<data>
hasInitialized: bool
state: data
hibernatableConns: list<PersistedHibernatableConn>
scheduledEvents: list<PersistedScheduleEvent>
hibernatableConns: list<HibernatableConn>
scheduledEvents: list<ScheduleEvent>
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# File System Driver Schema (v2)

# MARK: Actor State
type ActorKvEntry struct {
key: data
value: data
}

type ActorState struct {
actorId: str
name: str
key: list<str>
# KV storage map for actor and connection data
# Keys are strings (base64 encoded), values are byte arrays
kvStorage: list<ActorKvEntry>
createdAt: u64
}

# MARK: Actor Alarm
type ActorAlarm struct {
actorId: str
timestamp: uint
}
89 changes: 85 additions & 4 deletions rivetkit-typescript/packages/rivetkit/src/actor/conn.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import * as cbor from "cbor-x";
import onChange from "on-change";
import { isCborSerializable } from "@/common/utils";
import type * as protocol from "@/schemas/client-protocol/mod";
import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned";
import { arrayBuffersEqual, bufferToArrayBuffer } from "@/utils";
Expand Down Expand Up @@ -44,7 +46,13 @@ export class Conn<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
* This will only be persisted if using hibernatable WebSockets. If not,
* this is just used to hole state.
*/
__persist: PersistedConn<CP, CS>;
__persist!: PersistedConn<CP, CS>;

/** Raw persist object without the proxy wrapper */
#persistRaw: PersistedConn<CP, CS>;

/** Track if this connection's state has changed */
#changed = false;

get __driverState(): ConnDriverState | undefined {
return this.__socket?.driverState;
Expand Down Expand Up @@ -103,9 +111,9 @@ export class Conn<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
return false;
}
return (
this.#actor[PERSIST_SYMBOL].hibernatableWebSocket.findIndex((x) =>
this.#actor[PERSIST_SYMBOL].hibernatableConns.findIndex((conn) =>
arrayBuffersEqual(
x.requestId,
conn.hibernatableRequestId,
this.__persist.hibernatableRequestId!,
),
) > -1
Expand All @@ -131,7 +139,80 @@ export class Conn<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
persist: PersistedConn<CP, CS>,
) {
this.#actor = actor;
this.__persist = persist;
this.#persistRaw = persist;
this.#setupPersistProxy(persist);
}

/**
* Sets up the proxy for connection persistence with change tracking
*/
#setupPersistProxy(persist: PersistedConn<CP, CS>) {
// If this can't be proxied, return raw value
if (persist === null || typeof persist !== "object") {
this.__persist = persist;
return;
}

// Listen for changes to the object
this.__persist = onChange(
persist,
(
path: string,
value: any,
_previousValue: any,
_applyData: any,
) => {
// Validate CBOR serializability for state changes
if (path.startsWith("state")) {
let invalidPath = "";
if (
!isCborSerializable(
value,
(invalidPathPart: string) => {
invalidPath = invalidPathPart;
},
"",
)
) {
throw new errors.InvalidStateType({
path: path + (invalidPath ? `.${invalidPath}` : ""),
});
}
}

this.#changed = true;
this.#actor.rLog.debug({
msg: "conn onChange triggered",
connId: this.id,
path,
});

// Notify actor that this connection has changed
this.#actor.__markConnChanged(this);
},
{ ignoreDetached: true },
);
}

/**
* Returns whether this connection has unsaved changes
*/
get hasChanges(): boolean {
return this.#changed;
}

/**
* Marks changes as saved
*/
markSaved() {
this.#changed = false;
}

/**
* Gets the raw persist data for serialization
*/
get persistRaw(): PersistedConn<CP, CS> {
return this.#persistRaw;
}

#validateStateEnabled() {
Expand Down
23 changes: 20 additions & 3 deletions rivetkit-typescript/packages/rivetkit/src/actor/driver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,27 @@ export interface ActorDriver {

getContext(actorId: string): unknown;

readPersistedData(actorId: string): Promise<Uint8Array | undefined>;
// Batch KV operations
/** Batch write multiple key-value pairs. Keys and values are Uint8Arrays. */
kvBatchPut(
actorId: string,
entries: [Uint8Array, Uint8Array][],
): Promise<void>;

/** ActorInstance ensure that only one instance of writePersistedData is called in parallel at a time. */
writePersistedData(actorId: string, data: Uint8Array): Promise<void>;
/** Batch read multiple keys. Returns null for keys that don't exist. */
kvBatchGet(
actorId: string,
keys: Uint8Array[],
): Promise<(Uint8Array | null)[]>;

/** Batch delete multiple keys. */
kvBatchDelete(actorId: string, keys: Uint8Array[]): Promise<void>;

/** List all keys with a given prefix. */
kvListPrefix(
actorId: string,
prefix: Uint8Array,
): Promise<[Uint8Array, Uint8Array][]>;

// Schedule
/** ActorInstance ensure that only one instance of setAlarm is called in parallel at a time. */
Expand Down
Loading
Loading