diff --git a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/raw-http-request-properties.ts b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/raw-http-request-properties.ts index 5c89bca0a8..d441841d8a 100644 --- a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/raw-http-request-properties.ts +++ b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/raw-http-request-properties.ts @@ -1,9 +1,9 @@ -import { type ActorContext, actor } from "rivetkit"; +import { actor, type RequestContext } from "rivetkit"; export const rawHttpRequestPropertiesActor = actor({ actions: {}, onRequest( - ctx: ActorContext, + ctx: RequestContext, request: Request, ) { // Extract all relevant Request properties diff --git a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/raw-http.ts b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/raw-http.ts index 5723337c2b..66b7506900 100644 --- a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/raw-http.ts +++ b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/raw-http.ts @@ -1,12 +1,12 @@ import { Hono } from "hono"; -import { type ActorContext, actor } from "rivetkit"; +import { actor, type RequestContext } from "rivetkit"; export const rawHttpActor = actor({ state: { requestCount: 0, }, onRequest( - ctx: ActorContext, + ctx: RequestContext, request: Request, ) { const url = new URL(request.url); @@ -111,7 +111,7 @@ export const rawHttpHonoActor = actor({ return { router }; }, onRequest( - ctx: ActorContext, + ctx: RequestContext, request: Request, ) { // Use the Hono router from vars diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/config.ts b/rivetkit-typescript/packages/rivetkit/src/actor/config.ts index 704693d05f..9625cb229c 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/config.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/config.ts @@ -3,6 +3,8 @@ import type { UniversalWebSocket } from "@/common/websocket-interface"; import type { Conn } from "./conn/mod"; import type { ActionContext } from "./contexts/action"; import type { ActorContext } from "./contexts/actor"; +import type { RequestContext } from "./contexts/request"; +import type { WebSocketContext } from "./contexts/websocket"; import type { AnyDatabaseProvider } from "./database"; export type InitContext = ActorContext< @@ -407,11 +409,13 @@ interface BaseActorConfig< * This handler receives raw HTTP requests made to `/actors/{actorName}/http/*` endpoints. * Use this hook to handle custom HTTP patterns, REST APIs, or other HTTP-based protocols. * + * @param c The request context with access to the connection * @param request The raw HTTP request object + * @param opts Additional options * @returns A Response object to send back, or void to continue with default routing */ onRequest?: ( - c: ActorContext< + c: RequestContext< TState, TConnParams, TConnState, @@ -420,7 +424,6 @@ interface BaseActorConfig< TDatabase >, request: Request, - opts: {}, ) => Response | Promise; /** @@ -429,11 +432,12 @@ interface BaseActorConfig< * This handler receives WebSocket connections made to `/actors/{actorName}/websocket/*` endpoints. * Use this hook to handle custom WebSocket protocols, binary streams, or other WebSocket-based communication. * + * @param c The WebSocket context with access to the connection * @param websocket The raw WebSocket connection - * @param request The original HTTP upgrade request + * @param opts Additional options including the original HTTP upgrade request */ onWebSocket?: ( - c: ActorContext< + c: WebSocketContext< TState, TConnParams, TConnState, diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/request.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/request.ts new file mode 100644 index 0000000000..1aa07b9afc --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/request.ts @@ -0,0 +1,183 @@ +import type { ActorKey } from "@/actor/mod"; +import type { Client } from "@/client/client"; +import type { Logger } from "@/common/log"; +import type { Registry } from "@/registry/mod"; +import type { Conn, ConnId } from "../conn/mod"; +import type { AnyDatabaseProvider, InferDatabaseClient } from "../database"; +import type { SaveStateOptions } from "../instance/state-manager"; +import type { Schedule } from "../schedule"; +import type { ActorContext } from "./actor"; + +/** + * Context for raw HTTP request handlers (onRequest). + * + * @typeParam TState - The actor state type + * @typeParam TConnParams - The connection parameters type + * @typeParam TConnState - The connection state type + * @typeParam TVars - The actor variables type + * @typeParam TInput - The actor input type + * @typeParam TDatabase - The database provider type + */ +export class RequestContext< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase extends AnyDatabaseProvider, +> { + #actorContext: ActorContext< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase + >; + + /** + * Should not be called directly. + * + * @param actorContext - The actor context + * @param conn - The connection associated with the request + */ + constructor( + actorContext: ActorContext< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase + >, + public readonly conn: Conn< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase + >, + ) { + this.#actorContext = actorContext; + } + + /** + * Get the actor state + */ + get state(): TState { + return this.#actorContext.state; + } + + /** + * Get the actor variables + */ + get vars(): TVars { + return this.#actorContext.vars; + } + + /** + * Broadcasts an event to all connected clients. + */ + broadcast(name: string, ...args: any[]): void { + this.#actorContext.broadcast(name, ...args); + } + + /** + * Gets the logger instance. + */ + get log(): Logger { + return this.#actorContext.log; + } + + /** + * Gets actor ID. + */ + get actorId(): string { + return this.#actorContext.actorId; + } + + /** + * Gets the actor name. + */ + get name(): string { + return this.#actorContext.name; + } + + /** + * Gets the actor key. + */ + get key(): ActorKey { + return this.#actorContext.key; + } + + /** + * Gets the region. + */ + get region(): string { + return this.#actorContext.region; + } + + /** + * Gets the scheduler. + */ + get schedule(): Schedule { + return this.#actorContext.schedule; + } + + /** + * Gets the map of connections. + */ + get conns(): Map< + ConnId, + Conn + > { + return this.#actorContext.conns; + } + + /** + * Returns the client for the given registry. + */ + client>(): Client { + return this.#actorContext.client(); + } + + /** + * @experimental + */ + get db(): InferDatabaseClient { + return this.#actorContext.db; + } + + /** + * Forces the state to get saved. + */ + async saveState(opts: SaveStateOptions): Promise { + return this.#actorContext.saveState(opts); + } + + /** + * Prevents the actor from sleeping until promise is complete. + */ + waitUntil(promise: Promise): void { + this.#actorContext.waitUntil(promise); + } + + /** + * AbortSignal that fires when the actor is stopping. + */ + get abortSignal(): AbortSignal { + return this.#actorContext.abortSignal; + } + + /** + * Forces the actor to sleep. + * + * Not supported on all drivers. + * + * @experimental + */ + sleep() { + this.#actorContext.sleep(); + } +} diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/websocket.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/websocket.ts new file mode 100644 index 0000000000..8403432531 --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/websocket.ts @@ -0,0 +1,183 @@ +import type { ActorKey } from "@/actor/mod"; +import type { Client } from "@/client/client"; +import type { Logger } from "@/common/log"; +import type { Registry } from "@/registry/mod"; +import type { Conn, ConnId } from "../conn/mod"; +import type { AnyDatabaseProvider, InferDatabaseClient } from "../database"; +import type { SaveStateOptions } from "../instance/state-manager"; +import type { Schedule } from "../schedule"; +import type { ActorContext } from "./actor"; + +/** + * Context for raw WebSocket handlers (onWebSocket). + * + * @typeParam TState - The actor state type + * @typeParam TConnParams - The connection parameters type + * @typeParam TConnState - The connection state type + * @typeParam TVars - The actor variables type + * @typeParam TInput - The actor input type + * @typeParam TDatabase - The database provider type + */ +export class WebSocketContext< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase extends AnyDatabaseProvider, +> { + #actorContext: ActorContext< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase + >; + + /** + * Should not be called directly. + * + * @param actorContext - The actor context + * @param conn - The connection associated with the WebSocket + */ + constructor( + actorContext: ActorContext< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase + >, + public readonly conn: Conn< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase + >, + ) { + this.#actorContext = actorContext; + } + + /** + * Get the actor state + */ + get state(): TState { + return this.#actorContext.state; + } + + /** + * Get the actor variables + */ + get vars(): TVars { + return this.#actorContext.vars; + } + + /** + * Broadcasts an event to all connected clients. + */ + broadcast(name: string, ...args: any[]): void { + this.#actorContext.broadcast(name, ...args); + } + + /** + * Gets the logger instance. + */ + get log(): Logger { + return this.#actorContext.log; + } + + /** + * Gets actor ID. + */ + get actorId(): string { + return this.#actorContext.actorId; + } + + /** + * Gets the actor name. + */ + get name(): string { + return this.#actorContext.name; + } + + /** + * Gets the actor key. + */ + get key(): ActorKey { + return this.#actorContext.key; + } + + /** + * Gets the region. + */ + get region(): string { + return this.#actorContext.region; + } + + /** + * Gets the scheduler. + */ + get schedule(): Schedule { + return this.#actorContext.schedule; + } + + /** + * Gets the map of connections. + */ + get conns(): Map< + ConnId, + Conn + > { + return this.#actorContext.conns; + } + + /** + * Returns the client for the given registry. + */ + client>(): Client { + return this.#actorContext.client(); + } + + /** + * @experimental + */ + get db(): InferDatabaseClient { + return this.#actorContext.db; + } + + /** + * Forces the state to get saved. + */ + async saveState(opts: SaveStateOptions): Promise { + return this.#actorContext.saveState(opts); + } + + /** + * Prevents the actor from sleeping until promise is complete. + */ + waitUntil(promise: Promise): void { + this.#actorContext.waitUntil(promise); + } + + /** + * AbortSignal that fires when the actor is stopping. + */ + get abortSignal(): AbortSignal { + return this.#actorContext.abortSignal; + } + + /** + * Forces the actor to sleep. + * + * Not supported on all drivers. + * + * @experimental + */ + sleep() { + this.#actorContext.sleep(); + } +} diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts index b709e1dbac..ac8e6fe117 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts @@ -24,6 +24,8 @@ import { } from "../conn/mod"; import { ActionContext } from "../contexts/action"; import { ActorContext } from "../contexts/actor"; +import { RequestContext } from "../contexts/request"; +import { WebSocketContext } from "../contexts/websocket"; import type { AnyDatabaseProvider, InferDatabaseClient } from "../database"; import type { ActorDriver } from "../driver"; import * as errors from "../errors"; @@ -646,8 +648,8 @@ export class ActorInstance { // MARK: - HTTP/WebSocket Handlers async handleRawRequest( + conn: Conn, request: Request, - opts: Record, ): Promise { this.#assertReady(); @@ -656,11 +658,8 @@ export class ActorInstance { } try { - const response = await this.#config.onRequest( - this.actorContext, - request, - opts, - ); + const ctx = new RequestContext(this.actorContext, conn); + const response = await this.#config.onRequest(ctx, request); if (!response) { throw new errors.InvalidRequestHandlerResponse(); } @@ -677,6 +676,7 @@ export class ActorInstance { } async handleRawWebSocket( + conn: Conn, websocket: UniversalWebSocket, opts: { request: Request }, ): Promise { @@ -693,7 +693,8 @@ export class ActorInstance { this.#resetSleepTimer(); // Handle WebSocket - await this.#config.onWebSocket(this.actorContext, websocket, opts); + const ctx = new WebSocketContext(this.actorContext, conn); + await this.#config.onWebSocket(ctx, websocket, opts); // Save state if changed if (this.#stateManager.persistChanged && !stateBeforeHandler) { diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/mod.ts b/rivetkit-typescript/packages/rivetkit/src/actor/mod.ts index 93e0952330..3a329b243a 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/mod.ts @@ -76,6 +76,8 @@ export type * from "./config"; export type { Conn } from "./conn/mod"; export type { ActionContext } from "./contexts/action"; export type { ActorContext } from "./contexts/actor"; +export type { RequestContext } from "./contexts/request"; +export type { WebSocketContext } from "./contexts/websocket"; export type { ActionContextOf, ActorContextOf, diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts b/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts index 05ec9978a8..131f413225 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts @@ -392,7 +392,7 @@ export async function handleRawRequest( createdConn = conn; - return await actor.handleRawRequest(req, {}); + return await actor.handleRawRequest(conn, req); } finally { // Clean up the connection after the request completes if (createdConn) { @@ -474,7 +474,7 @@ export async function handleRawWebSocket( createdConn = conn; // Call the actor's onWebSocket handler with the adapted WebSocket - actor.handleRawWebSocket(adapter, { + actor.handleRawWebSocket(conn, adapter, { request: newRequest, }); } catch (error) {