Skip to content

Commit dba206d

Browse files
committed
fix(rivetkit): fix race condition with websocket open events
1 parent 4599c4a commit dba206d

File tree

15 files changed

+598
-625
lines changed

15 files changed

+598
-625
lines changed

rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/raw-websocket.ts

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,16 +51,17 @@ export const rawWebSocketActor = actor({
5151
}),
5252
);
5353
} else if (parsed.type === "getRequestInfo") {
54-
throw "TODO";
55-
// Send back the request URL info
56-
// websocket.send(
57-
// JSON.stringify({
58-
// type: "requestInfo",
59-
// url: opts.request.url,
60-
// pathname: new URL(opts.request.url).pathname,
61-
// search: new URL(opts.request.url).search,
62-
// }),
63-
// );
54+
// Send back the request URL info if available
55+
const url = ctx.request?.url || "ws://actor/websocket";
56+
const urlObj = new URL(url);
57+
websocket.send(
58+
JSON.stringify({
59+
type: "requestInfo",
60+
url: url,
61+
pathname: urlObj.pathname,
62+
search: urlObj.search,
63+
}),
64+
);
6465
} else {
6566
// Echo back
6667
websocket.send(data);

rivetkit-typescript/packages/rivetkit/src/actor/config.ts

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,14 @@ interface BaseActorConfig<
330330
* @returns Void or a Promise that resolves when connection handling is complete
331331
*/
332332
onConnect?: (
333-
c: OnConnectContext<TState, TVars, TInput, TDatabase>,
333+
c: OnConnectContext<
334+
TState,
335+
TConnParams,
336+
TConnState,
337+
TVars,
338+
TInput,
339+
TDatabase
340+
>,
334341
conn: Conn<TState, TConnParams, TConnState, TVars, TInput, TDatabase>,
335342
) => void | Promise<void>;
336343

rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/raw-websocket.ts

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import type { AnyConn } from "@/actor/conn/mod";
22
import type { AnyActorInstance } from "@/actor/instance/mod";
33
import type { UniversalWebSocket } from "@/common/websocket-interface";
4-
import type { ConnDriver, DriverReadyState } from "../driver";
4+
import { loggerWithoutContext } from "../../log";
5+
import { type ConnDriver, DriverReadyState } from "../driver";
56

67
/**
78
* Creates a raw WebSocket connection driver.
@@ -15,10 +16,11 @@ export function createRawWebSocketSocket(
1516
requestId: string,
1617
requestIdBuf: ArrayBuffer | undefined,
1718
hibernatable: boolean,
18-
websocket: UniversalWebSocket,
1919
closePromise: Promise<void>,
20-
): ConnDriver {
21-
return {
20+
): { driver: ConnDriver; setWebSocket(ws: UniversalWebSocket): void } {
21+
let websocket: UniversalWebSocket | undefined;
22+
23+
const driver: ConnDriver = {
2224
type: "raw-websocket",
2325
requestId,
2426
requestIdBuf,
@@ -32,6 +34,13 @@ export function createRawWebSocketSocket(
3234
_conn: AnyConn,
3335
reason?: string,
3436
) => {
37+
if (!websocket) {
38+
loggerWithoutContext().warn(
39+
"disconnecting raw ws without websocket",
40+
);
41+
return;
42+
}
43+
3544
// Close socket
3645
websocket.close(1000, reason);
3746

@@ -40,14 +49,21 @@ export function createRawWebSocketSocket(
4049
},
4150

4251
terminate: () => {
43-
(websocket as any).terminate?.();
52+
(websocket as any)?.terminate?.();
4453
},
4554

4655
getConnectionReadyState: (
4756
_actor: AnyActorInstance,
4857
_conn: AnyConn,
4958
): DriverReadyState | undefined => {
50-
return websocket.readyState;
59+
return websocket?.readyState ?? DriverReadyState.CONNECTING;
60+
},
61+
};
62+
63+
return {
64+
driver,
65+
setWebSocket(ws) {
66+
websocket = ws;
5167
},
5268
};
5369
}

rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/websocket.ts

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import type { AnyConn } from "@/actor/conn/mod";
33
import type { AnyActorInstance } from "@/actor/instance/mod";
44
import type { CachedSerializer, Encoding } from "@/actor/protocol/serde";
55
import type * as protocol from "@/schemas/client-protocol/mod";
6+
import { loggerWithoutContext } from "../../log";
67
import { type ConnDriver, DriverReadyState } from "../driver";
78

89
export type ConnDriverWebSocketState = Record<never, never>;
@@ -12,10 +13,12 @@ export function createWebSocketSocket(
1213
requestIdBuf: ArrayBuffer | undefined,
1314
hibernatable: boolean,
1415
encoding: Encoding,
15-
websocket: WSContext,
1616
closePromise: Promise<void>,
17-
): ConnDriver {
18-
return {
17+
): { driver: ConnDriver; setWebSocket(ws: WSContext): void } {
18+
// Wait for WS to open
19+
let websocket: WSContext | undefined;
20+
21+
const driver: ConnDriver = {
1922
type: "websocket",
2023
requestId,
2124
requestIdBuf,
@@ -25,6 +28,13 @@ export function createWebSocketSocket(
2528
conn: AnyConn,
2629
message: CachedSerializer<any, any, any>,
2730
) => {
31+
if (!websocket) {
32+
actor.rLog.warn({
33+
msg: "websocket not open",
34+
connId: conn.id,
35+
});
36+
return;
37+
}
2838
if (websocket.readyState !== DriverReadyState.OPEN) {
2939
actor.rLog.warn({
3040
msg: "attempting to send message to closed websocket, this is likely a bug in RivetKit",
@@ -83,6 +93,13 @@ export function createWebSocketSocket(
8393
_conn: AnyConn,
8494
reason?: string,
8595
) => {
96+
if (!websocket) {
97+
loggerWithoutContext().warn(
98+
"disconnecting ws without websocket",
99+
);
100+
return;
101+
}
102+
86103
// Close socket
87104
websocket.close(1000, reason);
88105

@@ -98,7 +115,14 @@ export function createWebSocketSocket(
98115
_actor: AnyActorInstance,
99116
_conn: AnyConn,
100117
): DriverReadyState | undefined => {
101-
return websocket.readyState;
118+
return websocket?.readyState ?? DriverReadyState.CONNECTING;
119+
},
120+
};
121+
122+
return {
123+
driver,
124+
setWebSocket(ws) {
125+
websocket = ws;
102126
},
103127
};
104128
}

rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,8 @@ import {
77
} from "@/schemas/client-protocol-zod/mod";
88
import { arrayBuffersEqual, bufferToArrayBuffer } from "@/utils";
99
import type { AnyDatabaseProvider } from "../database";
10-
import {
11-
ACTOR_INSTANCE_PERSIST_SYMBOL,
12-
type ActorInstance,
13-
} from "../instance/mod";
10+
import { InternalError } from "../errors";
11+
import type { ActorInstance } from "../instance/mod";
1412
import type { PersistedConn } from "../instance/persisted";
1513
import { CachedSerializer } from "../protocol/serde";
1614
import type { ConnDriver } from "./driver";
@@ -24,6 +22,7 @@ export type ConnId = string;
2422

2523
export type AnyConn = Conn<any, any, any, any, any, any>;
2624

25+
export const CONN_CONNECTED_SYMBOL = Symbol("connected");
2726
export const CONN_PERSIST_SYMBOL = Symbol("persist");
2827
export const CONN_DRIVER_SYMBOL = Symbol("driver");
2928
export const CONN_ACTOR_SYMBOL = Symbol("actor");
@@ -60,6 +59,16 @@ export class Conn<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
6059
return this.#actor;
6160
}
6261

62+
/** Connections exist before being connected to an actor. If true, this connection has been connected. */
63+
[CONN_CONNECTED_SYMBOL] = false;
64+
65+
#assertConnected() {
66+
if (!this[CONN_CONNECTED_SYMBOL])
67+
throw new InternalError(
68+
"Connection not connected yet. This happens when trying to use the connection in onBeforeConnect or createConnState.",
69+
);
70+
}
71+
6372
get [CONN_PERSIST_SYMBOL](): PersistedConn<CP, CS> {
6473
return this.#stateManager.persist;
6574
}
@@ -109,9 +118,7 @@ export class Conn<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
109118
return false;
110119
}
111120
return (
112-
(this.#actor as any)[
113-
ACTOR_INSTANCE_PERSIST_SYMBOL
114-
].hibernatableConns.findIndex((conn: any) =>
121+
this.#actor.persist.hibernatableConns.findIndex((conn: any) =>
115122
arrayBuffersEqual(
116123
conn.hibernatableRequestId,
117124
hibernatableRequestId,
@@ -191,6 +198,8 @@ export class Conn<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
191198
* @see {@link https://rivet.dev/docs/events|Events Documentation}
192199
*/
193200
send(eventName: string, ...args: unknown[]) {
201+
this.#assertConnected();
202+
194203
this.#actor.inspector.emitter.emit("eventFired", {
195204
type: "event",
196205
eventName,
@@ -244,7 +253,7 @@ export class Conn<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
244253
});
245254
}
246255

247-
this.#actor.connDisconnected(this, true);
256+
this.#actor.connectionManager.connDisconnected(this);
248257
} else {
249258
this.#actor.rLog.warn({
250259
msg: "missing connection driver state for disconnect",

rivetkit-typescript/packages/rivetkit/src/actor/conn/state-manager.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ export class StateManager<CP, CS> {
134134
});
135135

136136
// Notify actor that this connection has changed
137-
this.#conn[CONN_ACTOR_SYMBOL].markConnChanged(this.#conn);
137+
this.#conn[CONN_ACTOR_SYMBOL].connectionManager.markConnChanged(
138+
this.#conn,
139+
);
138140
}
139141
}

rivetkit-typescript/packages/rivetkit/src/actor/contexts/actor.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ export class ActorContext<
6060
* @param args - The arguments to send with the event.
6161
*/
6262
broadcast<Args extends Array<unknown>>(name: string, ...args: Args): void {
63-
this.#actor.broadcast(name, ...args);
63+
this.#actor.eventManager.broadcast(name, ...args);
6464
return;
6565
}
6666

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,22 @@
11
import type { AnyDatabaseProvider } from "../database";
2-
import { ConnInitContext } from "./conn-init";
2+
import { ConnContext } from "./conn";
33

44
/**
55
* Context for the onConnect lifecycle hook.
66
* Called when a connection is successfully established.
77
*/
88
export class OnConnectContext<
99
TState,
10+
TConnParams,
11+
TConnState,
1012
TVars,
1113
TInput,
1214
TDatabase extends AnyDatabaseProvider,
13-
> extends ConnInitContext<TState, TVars, TInput, TDatabase> {}
15+
> extends ConnContext<
16+
TState,
17+
TConnParams,
18+
TConnState,
19+
TVars,
20+
TInput,
21+
TDatabase
22+
> {}

0 commit comments

Comments
 (0)