Skip to content

Commit 37bad6b

Browse files
NathanFlurryclaude
andcommitted
refactor(rivetkit): add state manager for conn with symbol-based access
- Created StateManager class for connection state management - Moved state proxying and change tracking to dedicated manager - Changed conn methods to use symbols for internal access: - actor, stateEnabled, persistRaw (getters) - hasChanges(), markSaved(), sendMessage() (methods) - Updated all references across codebase to use new symbols - Maintains backward compatibility for public API (send, disconnect, etc) - Consistent naming pattern with instance/mod.ts StateManager 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 4aa3109 commit 37bad6b

File tree

7 files changed

+218
-124
lines changed

7 files changed

+218
-124
lines changed

rivetkit-typescript/packages/cloudflare-workers/src/actor-driver.ts

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@ import type {
66
} from "rivetkit";
77
import { lookupInRegistry } from "rivetkit";
88
import type { Client } from "rivetkit/client";
9-
import {
10-
type ActorDriver,
11-
type AnyActorInstance,
12-
type ManagerDriver,
9+
import type {
10+
ActorDriver,
11+
AnyActorInstance,
12+
ManagerDriver,
1313
} from "rivetkit/driver-helpers";
1414
import { promiseWithResolvers } from "rivetkit/utils";
1515
import { KEYS } from "./actor-handler-do";
@@ -239,7 +239,6 @@ export class CloudflareActorsActorDriver implements ActorDriver {
239239
// Persist data key
240240
return Uint8Array.from([1]);
241241
}
242-
243242
}
244243

245244
export function createCloudflareActorsActorDriverBuilder(

rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts

Lines changed: 49 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,16 @@
11
import * as cbor from "cbor-x";
2-
import onChange from "on-change";
3-
import { isCborSerializable } from "@/common/utils";
42
import type * as protocol from "@/schemas/client-protocol/mod";
53
import { TO_CLIENT_VERSIONED } from "@/schemas/client-protocol/versioned";
64
import { arrayBuffersEqual, bufferToArrayBuffer } from "@/utils";
75
import type { AnyDatabaseProvider } from "../database";
8-
import * as errors from "../errors";
96
import {
107
ACTOR_INSTANCE_PERSIST_SYMBOL,
118
type ActorInstance,
129
} from "../instance/mod";
1310
import type { PersistedConn } from "../instance/persisted";
1411
import { CachedSerializer } from "../protocol/serde";
1512
import type { ConnDriver } from "./driver";
13+
import { StateManager } from "./state-manager";
1614

1715
export function generateConnRequestId(): string {
1816
return crypto.randomUUID();
@@ -24,6 +22,12 @@ export type AnyConn = Conn<any, any, any, any, any, any>;
2422

2523
export const CONN_PERSIST_SYMBOL = Symbol("persist");
2624
export const CONN_DRIVER_SYMBOL = Symbol("driver");
25+
export const CONN_ACTOR_SYMBOL = Symbol("actor");
26+
export const CONN_STATE_ENABLED_SYMBOL = Symbol("stateEnabled");
27+
export const CONN_PERSIST_RAW_SYMBOL = Symbol("persistRaw");
28+
export const CONN_HAS_CHANGES_SYMBOL = Symbol("hasChanges");
29+
export const CONN_MARK_SAVED_SYMBOL = Symbol("markSaved");
30+
export const CONN_SEND_MESSAGE_SYMBOL = Symbol("sendMessage");
2731

2832
/**
2933
* Represents a client connection to a actor.
@@ -38,72 +42,66 @@ export class Conn<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
3842
// TODO: Remove this cyclical reference
3943
#actor: ActorInstance<S, CP, CS, V, I, DB>;
4044

41-
/**
42-
* The proxied state that notifies of changes automatically.
43-
*
44-
* Any data that should be stored indefinitely should be held within this
45-
* object.
46-
*
47-
* This will only be persisted if using hibernatable WebSockets. If not,
48-
* this is just used to hole state.
49-
*/
50-
[CONN_PERSIST_SYMBOL]!: PersistedConn<CP, CS>;
51-
52-
/** Raw persist object without the proxy wrapper */
53-
#persistRaw: PersistedConn<CP, CS>;
54-
55-
/** Track if this connection's state has changed */
56-
#changed = false;
45+
// MARK: - Managers
46+
#stateManager!: StateManager<CP, CS>;
5747

5848
/**
5949
* If undefined, then nothing is connected to this.
6050
*/
6151
[CONN_DRIVER_SYMBOL]?: ConnDriver;
6252

63-
public get params(): CP {
64-
return this[CONN_PERSIST_SYMBOL].params;
53+
// MARK: - Public Getters
54+
55+
get [CONN_ACTOR_SYMBOL](): ActorInstance<S, CP, CS, V, I, DB> {
56+
return this.#actor;
6557
}
6658

67-
public get stateEnabled() {
68-
return this.#actor.connStateEnabled;
59+
get [CONN_PERSIST_SYMBOL](): PersistedConn<CP, CS> {
60+
return this.#stateManager.persist;
61+
}
62+
63+
get params(): CP {
64+
return this.#stateManager.params;
65+
}
66+
67+
get [CONN_STATE_ENABLED_SYMBOL](): boolean {
68+
return this.#stateManager.stateEnabled;
6969
}
7070

7171
/**
7272
* Gets the current state of the connection.
7373
*
7474
* Throws an error if the state is not enabled.
7575
*/
76-
public get state(): CS {
77-
this.#validateStateEnabled();
78-
if (!this[CONN_PERSIST_SYMBOL].state)
79-
throw new Error("state should exists");
80-
return this[CONN_PERSIST_SYMBOL].state;
76+
get state(): CS {
77+
return this.#stateManager.state;
8178
}
8279

8380
/**
8481
* Sets the state of the connection.
8582
*
8683
* Throws an error if the state is not enabled.
8784
*/
88-
public set state(value: CS) {
89-
this.#validateStateEnabled();
90-
this[CONN_PERSIST_SYMBOL].state = value;
85+
set state(value: CS) {
86+
this.#stateManager.state = value;
9187
}
9288

9389
/**
9490
* Unique identifier for the connection.
9591
*/
96-
public get id(): ConnId {
97-
return this[CONN_PERSIST_SYMBOL].connId;
92+
get id(): ConnId {
93+
return this.#stateManager.persist.connId;
9894
}
9995

10096
/**
10197
* @experimental
10298
*
10399
* If the underlying connection can hibernate.
104100
*/
105-
public get isHibernatable(): boolean {
106-
if (!this[CONN_PERSIST_SYMBOL].hibernatableRequestId) {
101+
get isHibernatable(): boolean {
102+
const hibernatableRequestId =
103+
this.#stateManager.persist.hibernatableRequestId;
104+
if (!hibernatableRequestId) {
107105
return false;
108106
}
109107
return (
@@ -112,7 +110,7 @@ export class Conn<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
112110
].hibernatableConns.findIndex((conn: any) =>
113111
arrayBuffersEqual(
114112
conn.hibernatableRequestId,
115-
this[CONN_PERSIST_SYMBOL].hibernatableRequestId!,
113+
hibernatableRequestId,
116114
),
117115
) > -1
118116
);
@@ -121,8 +119,8 @@ export class Conn<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
121119
/**
122120
* Timestamp of the last time the connection was seen, i.e. the last time the connection was active and checked for liveness.
123121
*/
124-
public get lastSeen(): number {
125-
return this[CONN_PERSIST_SYMBOL].lastSeen;
122+
get lastSeen(): number {
123+
return this.#stateManager.persist.lastSeen;
126124
}
127125

128126
/**
@@ -132,94 +130,37 @@ export class Conn<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
132130
*
133131
* @protected
134132
*/
135-
public constructor(
133+
constructor(
136134
actor: ActorInstance<S, CP, CS, V, I, DB>,
137135
persist: PersistedConn<CP, CS>,
138136
) {
139137
this.#actor = actor;
140-
this.#persistRaw = persist;
141-
this.#setupPersistProxy(persist);
142-
}
143-
144-
/**
145-
* Sets up the proxy for connection persistence with change tracking
146-
*/
147-
#setupPersistProxy(persist: PersistedConn<CP, CS>) {
148-
// If this can't be proxied, return raw value
149-
if (persist === null || typeof persist !== "object") {
150-
this[CONN_PERSIST_SYMBOL] = persist;
151-
return;
152-
}
153-
154-
// Listen for changes to the object
155-
this[CONN_PERSIST_SYMBOL] = onChange(
156-
persist,
157-
(
158-
path: string,
159-
value: any,
160-
_previousValue: any,
161-
_applyData: any,
162-
) => {
163-
// Validate CBOR serializability for state changes
164-
if (path.startsWith("state")) {
165-
let invalidPath = "";
166-
if (
167-
!isCborSerializable(
168-
value,
169-
(invalidPathPart: string) => {
170-
invalidPath = invalidPathPart;
171-
},
172-
"",
173-
)
174-
) {
175-
throw new errors.InvalidStateType({
176-
path: path + (invalidPath ? `.${invalidPath}` : ""),
177-
});
178-
}
179-
}
180-
181-
this.#changed = true;
182-
this.#actor.rLog.debug({
183-
msg: "conn onChange triggered",
184-
connId: this.id,
185-
path,
186-
});
187-
188-
// Notify actor that this connection has changed
189-
this.#actor.markConnChanged(this);
190-
},
191-
{ ignoreDetached: true },
192-
);
138+
this.#stateManager = new StateManager(this);
139+
this.#stateManager.initPersistProxy(persist);
193140
}
194141

195142
/**
196143
* Returns whether this connection has unsaved changes
197144
*/
198-
get hasChanges(): boolean {
199-
return this.#changed;
145+
[CONN_HAS_CHANGES_SYMBOL](): boolean {
146+
return this.#stateManager.hasChanges();
200147
}
201148

202149
/**
203150
* Marks changes as saved
204151
*/
205-
markSaved() {
206-
this.#changed = false;
152+
[CONN_MARK_SAVED_SYMBOL]() {
153+
this.#stateManager.markSaved();
207154
}
208155

209156
/**
210157
* Gets the raw persist data for serialization
211158
*/
212-
get persistRaw(): PersistedConn<CP, CS> {
213-
return this.#persistRaw;
214-
}
215-
216-
#validateStateEnabled() {
217-
if (!this.stateEnabled) {
218-
throw new errors.ConnStateNotEnabled();
219-
}
159+
get [CONN_PERSIST_RAW_SYMBOL](): PersistedConn<CP, CS> {
160+
return this.#stateManager.persistRaw;
220161
}
221162

222-
public sendMessage(message: CachedSerializer<protocol.ToClient>) {
163+
[CONN_SEND_MESSAGE_SYMBOL](message: CachedSerializer<protocol.ToClient>) {
223164
if (this[CONN_DRIVER_SYMBOL]) {
224165
const driver = this[CONN_DRIVER_SYMBOL];
225166
if (driver.sendMessage) {
@@ -245,14 +186,14 @@ export class Conn<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
245186
* @param args - The arguments for the event.
246187
* @see {@link https://rivet.dev/docs/events|Events Documentation}
247188
*/
248-
public send(eventName: string, ...args: unknown[]) {
189+
send(eventName: string, ...args: unknown[]) {
249190
this.#actor.inspector.emitter.emit("eventFired", {
250191
type: "event",
251192
eventName,
252193
args,
253194
connId: this.id,
254195
});
255-
this.sendMessage(
196+
this[CONN_SEND_MESSAGE_SYMBOL](
256197
new CachedSerializer<protocol.ToClient>(
257198
{
258199
body: {
@@ -273,7 +214,7 @@ export class Conn<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
273214
*
274215
* @param reason - The reason for disconnection.
275216
*/
276-
public async disconnect(reason?: string) {
217+
async disconnect(reason?: string) {
277218
if (this[CONN_DRIVER_SYMBOL]) {
278219
const driver = this[CONN_DRIVER_SYMBOL];
279220
if (driver.disconnect) {

0 commit comments

Comments
 (0)