diff --git a/rivetkit-typescript/packages/rivetkit/schemas/actor-persist/v2.bare b/rivetkit-typescript/packages/rivetkit/schemas/actor-persist/v2.bare index 836c67cf06..c0403ebc70 100644 --- a/rivetkit-typescript/packages/rivetkit/schemas/actor-persist/v2.bare +++ b/rivetkit-typescript/packages/rivetkit/schemas/actor-persist/v2.bare @@ -51,5 +51,5 @@ type PersistedActor struct { state: data connections: list scheduledEvents: list - hibernatableWebSocket: list + hibernatableWebSockets: list } diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/raw-websocket.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/raw-websocket.ts new file mode 100644 index 0000000000..969c575a7f --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/raw-websocket.ts @@ -0,0 +1,52 @@ +import type { AnyConn } from "@/actor/conn/mod"; +import type { AnyActorInstance } from "@/actor/instance/mod"; +import type { UniversalWebSocket } from "@/common/websocket-interface"; +import type { ConnDriver, DriverReadyState } from "../driver"; + +/** + * Creates a raw WebSocket connection driver. + * + * This driver is used for raw WebSocket connections that don't use the RivetKit protocol. + * Unlike the standard WebSocket driver, this doesn't have sendMessage since raw WebSockets + * don't handle messages from the RivetKit protocol - they handle messages directly in the + * actor's onWebSocket handler. + */ +export function createRawWebSocketSocket( + requestId: string, + requestIdBuf: ArrayBuffer | undefined, + hibernatable: boolean, + websocket: UniversalWebSocket, + closePromise: Promise, +): ConnDriver { + return { + requestId, + requestIdBuf, + hibernatable, + + // No sendMessage implementation since this is a raw WebSocket that doesn't + // handle messages from the RivetKit protocol + + disconnect: async ( + _actor: AnyActorInstance, + _conn: AnyConn, + reason?: string, + ) => { + // Close socket + websocket.close(1000, reason); + + // Wait for socket to close gracefully + await closePromise; + }, + + terminate: () => { + (websocket as any).terminate?.(); + }, + + getConnectionReadyState: ( + _actor: AnyActorInstance, + _conn: AnyConn, + ): DriverReadyState | undefined => { + return websocket.readyState; + }, + }; +} diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts index 40ebe6d00c..ffc097b3d5 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts @@ -41,7 +41,6 @@ enum CanSleep { NotReady, ActiveConns, ActiveHonoHttpRequests, - ActiveRawWebSockets, } /** Actor type alias with all `any` types. Used for `extends` in classes referencing this actor. */ @@ -100,7 +99,6 @@ export class ActorInstance { // MARK: - HTTP/WebSocket Tracking #activeHonoHttpRequests = 0; - #activeRawWebSockets = new Set(); // MARK: - Deprecated (kept for compatibility) #schedule!: Schedule; @@ -673,13 +671,9 @@ export class ActorInstance { try { const stateBeforeHandler = this.#stateManager.persistChanged; - // Track active websocket - this.#activeRawWebSockets.add(websocket); + // Reset sleep timer when handling WebSocket this.#resetSleepTimer(); - // Setup WebSocket event handlers (simplified for brevity) - this.#setupWebSocketHandlers(websocket); - // Handle WebSocket await this.#config.onWebSocket(this.actorContext, websocket, opts); @@ -958,18 +952,6 @@ export class ActorInstance { } } - #setupWebSocketHandlers(websocket: UniversalWebSocket) { - // Simplified WebSocket handler setup - // Full implementation would track hibernatable websockets - const onSocketClosed = () => { - this.#activeRawWebSockets.delete(websocket); - this.#resetSleepTimer(); - }; - - websocket.addEventListener("close", onSocketClosed); - websocket.addEventListener("error", onSocketClosed); - } - #resetSleepTimer() { if (this.#config.options.noSleep || !this.#sleepingSupported) return; if (this.#stopCalled) return; @@ -1001,8 +983,6 @@ export class ActorInstance { if (!this.#ready) return CanSleep.NotReady; if (this.#activeHonoHttpRequests > 0) return CanSleep.ActiveHonoHttpRequests; - if (this.#activeRawWebSockets.size > 0) - return CanSleep.ActiveRawWebSockets; for (const _conn of this.#connectionManager.connections.values()) { return CanSleep.ActiveConns; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts b/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts index b955801972..1a8885d6cb 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts @@ -36,6 +36,7 @@ import { promiseWithResolvers, } from "@/utils"; import { createHttpSocket } from "./conn/drivers/http"; +import { createRawWebSocketSocket } from "./conn/drivers/raw-websocket"; import { createWebSocketSocket } from "./conn/drivers/websocket"; import type { ActorDriver } from "./driver"; import { loggerWithoutContext } from "./log"; @@ -383,12 +384,19 @@ export async function handleRawWebSocketHandler( ): Promise { const actor = await actorDriver.loadActor(actorId); + // Promise used to wait for the websocket close in `disconnect` + const closePromiseResolvers = promiseWithResolvers(); + + // Track connection outside of scope for cleanup + let createdConn: AnyConn | undefined; + // Return WebSocket event handlers return { - onOpen: (evt: any, ws: any) => { + onOpen: async (evt: any, ws: any) => { // Extract rivetRequestId provided by engine runner const rivetRequestId = evt?.rivetRequestId; const isHibernatable = + !!rivetRequestId && actor[ ACTOR_INSTANCE_PERSIST_SYMBOL ].hibernatableConns.findIndex((conn) => @@ -424,10 +432,36 @@ export async function handleRawWebSocketHandler( toUrl: newRequest.url, }); - // Call the actor's onWebSocket handler with the adapted WebSocket - actor.handleWebSocket(adapter, { - request: newRequest, - }); + try { + // Create connection using actor.createConn - this handles deduplication for hibernatable connections + const requestId = rivetRequestId + ? String(rivetRequestId) + : crypto.randomUUID(); + const conn = await actor.createConn( + createRawWebSocketSocket( + requestId, + rivetRequestId, + isHibernatable, + adapter, + closePromiseResolvers.promise, + ), + {}, // No parameters for raw WebSocket + newRequest, + ); + + createdConn = conn; + + // Call the actor's onWebSocket handler with the adapted WebSocket + actor.handleWebSocket(adapter, { + request: newRequest, + }); + } catch (error) { + actor.rLog.error({ + msg: "failed to create raw WebSocket connection", + error: String(error), + }); + ws.close(1011, "Failed to create connection"); + } }, onMessage: (event: any, ws: any) => { // Find the adapter for this WebSocket @@ -442,6 +476,15 @@ export async function handleRawWebSocketHandler( if (adapter) { adapter._handleClose(evt?.code || 1006, evt?.reason || ""); } + + // Resolve the close promise + closePromiseResolvers.resolve(); + + // Clean up the connection + if (createdConn) { + const wasClean = evt?.wasClean || evt?.code === 1000; + actor.connDisconnected(createdConn, wasClean); + } }, onError: (error: any, ws: any) => { // Find the adapter for this WebSocket diff --git a/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts b/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts index 74b3544edc..55512ea142 100644 --- a/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts +++ b/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts @@ -76,8 +76,8 @@ export class EngineActorDriver implements ActorDriver { #runnerStopped: PromiseWithResolvers = promiseWithResolvers(); #isRunnerStopped: boolean = false; - // WebSocket message acknowledgment debouncing - #wsAckQueue: Map< + // WebSocket message acknowledgment debouncing for hibernatable websockets + #hibernatableWebSocketAckQueue: Map< string, { requestIdBuf: ArrayBuffer; messageIndex: number } > = new Map(); @@ -176,7 +176,9 @@ export class EngineActorDriver implements ActorDriver { msg: "checking hibernatable websockets", requestId: idToStr(requestId), existingHibernatableWebSockets: hibernatableArray.length, + actorId, }); + const existingWs = hibernatableArray.find((conn) => arrayBuffersEqual(conn.hibernatableRequestId, requestId), ); @@ -184,14 +186,19 @@ export class EngineActorDriver implements ActorDriver { // Determine configuration for new WS let hibernationConfig: HibernationConfig; if (existingWs) { + // Convert msgIndex to number, treating -1 as undefined (no messages processed yet) + const lastMsgIndex = + existingWs.msgIndex >= 0n + ? Number(existingWs.msgIndex) + : undefined; logger().debug({ msg: "found existing hibernatable websocket", requestId: idToStr(requestId), - lastMsgIndex: existingWs.msgIndex, + lastMsgIndex: lastMsgIndex ?? -1, }); hibernationConfig = { enabled: true, - lastMsgIndex: Number(existingWs.msgIndex), + lastMsgIndex, }; } else { logger().debug({ @@ -268,6 +275,7 @@ export class EngineActorDriver implements ActorDriver { logger().debug({ msg: "updated existing hibernatable websocket timestamp", requestId: idToStr(requestId), + currentMsgIndex: existingWs.msgIndex, }); existingWs.lastSeenTimestamp = Date.now(); } else if (path === PATH_CONNECT) { @@ -277,7 +285,7 @@ export class EngineActorDriver implements ActorDriver { msg: "will create hibernatable conn when connection is created", requestId: idToStr(requestId), }); - // Note: The actual hibernatable connection is created in instance.ts + // Note: The actual hibernatable connection is created in connection-manager.ts // when createConn is called with a hibernatable requestId } @@ -302,7 +310,10 @@ export class EngineActorDriver implements ActorDriver { // // Gateway timeout configured to 30s // https://github.com/rivet-dev/rivet/blob/222dae87e3efccaffa2b503de40ecf8afd4e31eb/engine/packages/pegboard-gateway/src/shared_state.rs#L17 - this.#wsAckFlushInterval = setInterval(() => this.#flushWsAcks(), 1000); + this.#wsAckFlushInterval = setInterval( + () => this.#flushHibernatableWebSocketAcks(), + 1000, + ); } async #loadActorHandler(actorId: string): Promise { @@ -321,17 +332,17 @@ export class EngineActorDriver implements ActorDriver { return handler.actor; } - #flushWsAcks(): void { - if (this.#wsAckQueue.size === 0) return; + #flushHibernatableWebSocketAcks(): void { + if (this.#hibernatableWebSocketAckQueue.size === 0) return; for (const { requestIdBuf: requestId, messageIndex: index, - } of this.#wsAckQueue.values()) { + } of this.#hibernatableWebSocketAckQueue.values()) { this.#runner.sendWebsocketMessageAck(requestId, index); } - this.#wsAckQueue.clear(); + this.#hibernatableWebSocketAckQueue.clear(); } getContext(actorId: string): DriverContext { @@ -608,39 +619,171 @@ export class EngineActorDriver implements ActorDriver { invariant(event.rivetRequestId, "missing rivetRequestId"); invariant(event.rivetMessageIndex, "missing rivetMessageIndex"); - // Track only the highest seen message index per request - // Convert ArrayBuffer to string for Map key - const currentEntry = this.#wsAckQueue.get(requestId); - if (currentEntry) { - if (event.rivetMessageIndex > currentEntry.messageIndex) { - currentEntry.messageIndex = event.rivetMessageIndex; - } else { - logger().warn({ - msg: "received lower index than ack queue for message", + // Handle hibernatable WebSockets: + // - Save msgIndex for WS restoration + // - Queue WS acks + const actorHandler = this.#actors.get(actorId); + if (actorHandler?.actor) { + const hibernatableWs = actorHandler.actor[ + ACTOR_INSTANCE_PERSIST_SYMBOL + ].hibernatableConns.find((conn: any) => + arrayBuffersEqual(conn.hibernatableRequestId, requestIdBuf), + ); + + if (hibernatableWs) { + // Update msgIndex for next WebSocket open msgIndex restoration + const oldMsgIndex = hibernatableWs.msgIndex; + hibernatableWs.msgIndex = event.rivetMessageIndex; + hibernatableWs.lastSeenTimestamp = Date.now(); + + logger().debug({ + msg: "updated hibernatable websocket msgIndex in engine driver", requestId, - queuedMessageIndex: currentEntry, - eventMessageIndex: event.rivetMessageIndex, + oldMsgIndex: oldMsgIndex.toString(), + newMsgIndex: event.rivetMessageIndex, + actorId, }); + + // Track msgIndex for sending acks + const currentEntry = + this.#hibernatableWebSocketAckQueue.get(requestId); + if (currentEntry) { + const previousIndex = currentEntry.messageIndex; + + // Warn about any non-sequential message indices + if (event.rivetMessageIndex !== previousIndex + 1) { + logger().warn({ + msg: "websocket message index out of sequence", + requestId, + actorId, + previousIndex, + expectedIndex: previousIndex + 1, + receivedIndex: event.rivetMessageIndex, + sequenceType: + event.rivetMessageIndex < previousIndex + ? "regressed" + : event.rivetMessageIndex === + previousIndex + ? "duplicate" + : "gap/skipped", + gap: + event.rivetMessageIndex > previousIndex + ? event.rivetMessageIndex - + previousIndex - + 1 + : 0, + }); + } + + // Update to the highest seen index + if (event.rivetMessageIndex > previousIndex) { + currentEntry.messageIndex = event.rivetMessageIndex; + } + } else { + this.#hibernatableWebSocketAckQueue.set(requestId, { + requestIdBuf, + messageIndex: event.rivetMessageIndex, + }); + } } } else { - this.#wsAckQueue.set(requestId, { - requestIdBuf, + // Warn if we receive a message for a hibernatable websocket but can't find the actor + logger().warn({ + msg: "received websocket message but actor not found for hibernatable tracking", + actorId, + requestId, messageIndex: event.rivetMessageIndex, + hasHandler: !!actorHandler, + hasActor: !!actorHandler?.actor, }); } }); websocket.addEventListener("close", (event) => { // Flush any pending acks before closing - this.#flushWsAcks(); + this.#flushHibernatableWebSocketAcks(); + + // Clean up hibernatable WebSocket + this.#cleanupHibernatableWebSocket( + actorId, + requestIdBuf, + requestId, + "close", + event, + ); + wsHandlerPromise.then((x) => x.onClose?.(event, wsContext)); }); websocket.addEventListener("error", (event) => { + // Clean up hibernatable WebSocket on error + this.#cleanupHibernatableWebSocket( + actorId, + requestIdBuf, + requestId, + "error", + event, + ); + wsHandlerPromise.then((x) => x.onError?.(event, wsContext)); }); } + /** + * Helper method to clean up hibernatable WebSocket entries + * Eliminates duplication between close and error handlers + */ + #cleanupHibernatableWebSocket( + actorId: string, + requestIdBuf: ArrayBuffer, + requestId: string, + eventType: "close" | "error", + event?: any, + ) { + const actorHandler = this.#actors.get(actorId); + if (actorHandler?.actor) { + const hibernatableArray = + actorHandler.actor[ACTOR_INSTANCE_PERSIST_SYMBOL] + .hibernatableConns; + const wsIndex = hibernatableArray.findIndex((conn: any) => + arrayBuffersEqual(conn.hibernatableRequestId, requestIdBuf), + ); + + if (wsIndex !== -1) { + const removed = hibernatableArray.splice(wsIndex, 1); + const logData: any = { + msg: `removed hibernatable websocket on ${eventType}`, + requestId, + actorId, + removedMsgIndex: + removed[0]?.msgIndex?.toString() ?? "unknown", + }; + // Add error context if this is an error event + if (eventType === "error" && event) { + logData.error = event; + } + logger().debug(logData); + } + } else { + // Warn if actor not found during cleanup + const warnData: any = { + msg: `websocket ${eventType === "close" ? "closed" : "error"} but actor not found for hibernatable cleanup`, + actorId, + requestId, + hasHandler: !!actorHandler, + hasActor: !!actorHandler?.actor, + }; + // Add error context if this is an error event + if (eventType === "error" && event) { + warnData.error = event; + } + logger().warn(warnData); + } + + // Also remove from ack queue + this.#hibernatableWebSocketAckQueue.delete(requestId); + } + startSleep(actorId: string) { this.#runner.sleepActor(actorId); } @@ -700,7 +843,7 @@ export class EngineActorDriver implements ActorDriver { } // Flush any remaining acks - this.#flushWsAcks(); + this.#flushHibernatableWebSocketAcks(); await this.#runner.shutdown(immediate); } diff --git a/rivetkit-typescript/packages/rivetkit/src/schemas/actor-persist/versioned.ts b/rivetkit-typescript/packages/rivetkit/src/schemas/actor-persist/versioned.ts index 068c7fc5b6..d6fa27cefb 100644 --- a/rivetkit-typescript/packages/rivetkit/src/schemas/actor-persist/versioned.ts +++ b/rivetkit-typescript/packages/rivetkit/src/schemas/actor-persist/versioned.ts @@ -17,7 +17,7 @@ const migrations = new Map>([ ...conn, hibernatableRequestId: null, })), - hibernatableWebSocket: [], + hibernatableWebSockets: [], }), ], [ @@ -30,7 +30,7 @@ const migrations = new Map>([ for (const conn of v2Data.connections) { if (conn.hibernatableRequestId) { // Find the matching hibernatable WebSocket - const ws = v2Data.hibernatableWebSocket.find((ws) => + const ws = v2Data.hibernatableWebSockets.find((ws) => Buffer.from(ws.requestId).equals( Buffer.from(conn.hibernatableRequestId!), ),