Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,5 @@ type PersistedActor struct {
state: data
connections: list<PersistedConnection>
scheduledEvents: list<PersistedScheduleEvent>
hibernatableWebSocket: list<PersistedHibernatableWebSocket>
hibernatableWebSockets: list<PersistedHibernatableWebSocket>
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import type { AnyConn } from "@/actor/conn/mod";
import type { AnyActorInstance } from "@/actor/instance/mod";
import type { UniversalWebSocket } from "@/common/websocket-interface";
import type { ConnDriver, DriverReadyState } from "../driver";

/**
* Creates a raw WebSocket connection driver.
*
* This driver is used for raw WebSocket connections that don't use the RivetKit protocol.
* Unlike the standard WebSocket driver, this doesn't have sendMessage since raw WebSockets
* don't handle messages from the RivetKit protocol - they handle messages directly in the
* actor's onWebSocket handler.
*/
export function createRawWebSocketSocket(
requestId: string,
requestIdBuf: ArrayBuffer | undefined,
hibernatable: boolean,
websocket: UniversalWebSocket,
closePromise: Promise<void>,
): ConnDriver {
return {
requestId,
requestIdBuf,
hibernatable,

// No sendMessage implementation since this is a raw WebSocket that doesn't
// handle messages from the RivetKit protocol

disconnect: async (
_actor: AnyActorInstance,
_conn: AnyConn,
reason?: string,
) => {
// Close socket
websocket.close(1000, reason);

// Wait for socket to close gracefully
await closePromise;
},

terminate: () => {
(websocket as any).terminate?.();
},

getConnectionReadyState: (
_actor: AnyActorInstance,
_conn: AnyConn,
): DriverReadyState | undefined => {
return websocket.readyState;
},
};
}
22 changes: 1 addition & 21 deletions rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ enum CanSleep {
NotReady,
ActiveConns,
ActiveHonoHttpRequests,
ActiveRawWebSockets,
}

/** Actor type alias with all `any` types. Used for `extends` in classes referencing this actor. */
Expand Down Expand Up @@ -100,7 +99,6 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {

// MARK: - HTTP/WebSocket Tracking
#activeHonoHttpRequests = 0;
#activeRawWebSockets = new Set<UniversalWebSocket>();

// MARK: - Deprecated (kept for compatibility)
#schedule!: Schedule;
Expand Down Expand Up @@ -673,13 +671,9 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
try {
const stateBeforeHandler = this.#stateManager.persistChanged;

// Track active websocket
this.#activeRawWebSockets.add(websocket);
// Reset sleep timer when handling WebSocket
this.#resetSleepTimer();

// Setup WebSocket event handlers (simplified for brevity)
this.#setupWebSocketHandlers(websocket);

// Handle WebSocket
await this.#config.onWebSocket(this.actorContext, websocket, opts);

Expand Down Expand Up @@ -958,18 +952,6 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
}
}

#setupWebSocketHandlers(websocket: UniversalWebSocket) {
// Simplified WebSocket handler setup
// Full implementation would track hibernatable websockets
const onSocketClosed = () => {
this.#activeRawWebSockets.delete(websocket);
this.#resetSleepTimer();
};

websocket.addEventListener("close", onSocketClosed);
websocket.addEventListener("error", onSocketClosed);
}

#resetSleepTimer() {
if (this.#config.options.noSleep || !this.#sleepingSupported) return;
if (this.#stopCalled) return;
Expand Down Expand Up @@ -1001,8 +983,6 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
if (!this.#ready) return CanSleep.NotReady;
if (this.#activeHonoHttpRequests > 0)
return CanSleep.ActiveHonoHttpRequests;
if (this.#activeRawWebSockets.size > 0)
return CanSleep.ActiveRawWebSockets;

for (const _conn of this.#connectionManager.connections.values()) {
return CanSleep.ActiveConns;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import {
promiseWithResolvers,
} from "@/utils";
import { createHttpSocket } from "./conn/drivers/http";
import { createRawWebSocketSocket } from "./conn/drivers/raw-websocket";
import { createWebSocketSocket } from "./conn/drivers/websocket";
import type { ActorDriver } from "./driver";
import { loggerWithoutContext } from "./log";
Expand Down Expand Up @@ -383,12 +384,19 @@ export async function handleRawWebSocketHandler(
): Promise<UpgradeWebSocketArgs> {
const actor = await actorDriver.loadActor(actorId);

// Promise used to wait for the websocket close in `disconnect`
const closePromiseResolvers = promiseWithResolvers<void>();

// Track connection outside of scope for cleanup
let createdConn: AnyConn | undefined;

// Return WebSocket event handlers
return {
onOpen: (evt: any, ws: any) => {
onOpen: async (evt: any, ws: any) => {
// Extract rivetRequestId provided by engine runner
const rivetRequestId = evt?.rivetRequestId;
const isHibernatable =
!!rivetRequestId &&
actor[
ACTOR_INSTANCE_PERSIST_SYMBOL
].hibernatableConns.findIndex((conn) =>
Expand Down Expand Up @@ -424,10 +432,36 @@ export async function handleRawWebSocketHandler(
toUrl: newRequest.url,
});

// Call the actor's onWebSocket handler with the adapted WebSocket
actor.handleWebSocket(adapter, {
request: newRequest,
});
try {
// Create connection using actor.createConn - this handles deduplication for hibernatable connections
const requestId = rivetRequestId
? String(rivetRequestId)
: crypto.randomUUID();
const conn = await actor.createConn(
createRawWebSocketSocket(
requestId,
rivetRequestId,
isHibernatable,
adapter,
closePromiseResolvers.promise,
),
{}, // No parameters for raw WebSocket
newRequest,
);

createdConn = conn;

// Call the actor's onWebSocket handler with the adapted WebSocket
actor.handleWebSocket(adapter, {
request: newRequest,
});
} catch (error) {
actor.rLog.error({
msg: "failed to create raw WebSocket connection",
error: String(error),
});
ws.close(1011, "Failed to create connection");
}
},
onMessage: (event: any, ws: any) => {
// Find the adapter for this WebSocket
Expand All @@ -442,6 +476,15 @@ export async function handleRawWebSocketHandler(
if (adapter) {
adapter._handleClose(evt?.code || 1006, evt?.reason || "");
}

// Resolve the close promise
closePromiseResolvers.resolve();

// Clean up the connection
if (createdConn) {
const wasClean = evt?.wasClean || evt?.code === 1000;
actor.connDisconnected(createdConn, wasClean);
}
},
onError: (error: any, ws: any) => {
// Find the adapter for this WebSocket
Expand Down
Loading
Loading