diff --git a/examples/linear-coding-agent/src/actors/coding-agent/linear.ts b/examples/linear-coding-agent/src/actors/coding-agent/linear.ts index 945be19a5..a92cd2315 100644 --- a/examples/linear-coding-agent/src/actors/coding-agent/linear.ts +++ b/examples/linear-coding-agent/src/actors/coding-agent/linear.ts @@ -593,7 +593,7 @@ async function createPRForIssue( c.state.github.prInfo = await github.createPullRequest( c, `${title}`, // Just use the issue title - `Closes ${issueFriendlyId}\n\nImplements changes requested in Linear issue.\n\n${summary}\n\n*Authored by ActorCore Coding Agent*`, // Include "Closes" keyword + `Closes ${issueFriendlyId}\n\nImplements changes requested in Linear issue.\n\n${summary}\n\n*Authored by RivetKit Coding Agent*`, // Include "Closes" keyword c.state.github.branchName, c.state.github.baseBranch, ); @@ -713,4 +713,4 @@ export async function getIssueStatus( console.error(`[LINEAR] Failed to get issue status:`, error); return null; } -} \ No newline at end of file +} diff --git a/packages/actor/scripts/dump-openapi.ts b/packages/actor/scripts/dump-openapi.ts index 1a69a6b06..24319616d 100644 --- a/packages/actor/scripts/dump-openapi.ts +++ b/packages/actor/scripts/dump-openapi.ts @@ -11,6 +11,7 @@ import { OpenAPIHono } from "@hono/zod-openapi"; import { VERSION } from "@/utils"; import * as fs from "node:fs/promises"; import { resolve } from "node:path"; +import type { ClientDriver } from "@/client/client"; function main() { const appConfig: AppConfig = AppConfigSchema.parse({ actors: {} }); @@ -40,13 +41,20 @@ function main() { }, }; - const managerRouter = createManagerRouter(appConfig, driverConfig, { - proxyMode: { - inline: { - handlers: sharedConnectionHandlers, - }, - }, - }) as unknown as OpenAPIHono; + const clientDriver: ClientDriver = { + action: unimplemented, + resolveActorId: unimplemented, + connectWebSocket: unimplemented, + connectSse: unimplemented, + sendHttpMessage: unimplemented, + }; + + const managerRouter = createManagerRouter( + appConfig, + driverConfig, + clientDriver, + {}, + ) as unknown as OpenAPIHono; const openApiDoc = managerRouter.getOpenAPIDocument({ openapi: "3.0.0", diff --git a/packages/actor/src/actor/router-endpoints.ts b/packages/actor/src/actor/router-endpoints.ts index 9d1b6f3ea..dd07e57b8 100644 --- a/packages/actor/src/actor/router-endpoints.ts +++ b/packages/actor/src/actor/router-endpoints.ts @@ -46,7 +46,7 @@ export interface ConnectSseOutput { } export interface ActionOpts { - req: HonoRequest; + req?: HonoRequest; params: unknown; actionName: string; actionArgs: unknown[]; @@ -266,86 +266,6 @@ export async function handleSseConnect( }); } -/** - * Creates an action handler - */ -export async function handleAction( - c: HonoContext, - appConfig: AppConfig, - driverConfig: DriverConfig, - handler: (opts: ActionOpts) => Promise, - actionName: string, - actorId: string, -) { - const encoding = getRequestEncoding(c.req, false); - const parameters = getRequestConnParams(c.req, appConfig, driverConfig); - - logger().debug("handling action", { actionName, encoding }); - - // Validate incoming request - let actionArgs: unknown[]; - if (encoding === "json") { - try { - actionArgs = await c.req.json(); - } catch (err) { - throw new errors.InvalidActionRequest("Invalid JSON"); - } - - if (!Array.isArray(actionArgs)) { - throw new errors.InvalidActionRequest("Action arguments must be an array"); - } - } else if (encoding === "cbor") { - try { - const value = await c.req.arrayBuffer(); - const uint8Array = new Uint8Array(value); - const deserialized = await deserialize( - uint8Array as unknown as InputData, - encoding, - ); - - // Validate using the action schema - const result = protoHttpAction.ActionRequestSchema.safeParse(deserialized); - if (!result.success) { - throw new errors.InvalidActionRequest("Invalid action request format"); - } - - actionArgs = result.data.a; - } catch (err) { - throw new errors.InvalidActionRequest( - `Invalid binary format: ${stringifyError(err)}`, - ); - } - } else { - return assertUnreachable(encoding); - } - - // Invoke the action - const result = await handler({ - req: c.req, - params: parameters, - actionName: actionName, - actionArgs: actionArgs, - actorId, - }); - - // Encode the response - if (encoding === "json") { - return c.json(result.output as Record); - } else if (encoding === "cbor") { - // Use serialize from serde.ts instead of custom encoder - const responseData = { - o: result.output, // Use the format expected by ResponseOkSchema - }; - const serialized = serialize(responseData, encoding); - - return c.body(serialized as Uint8Array, 200, { - "Content-Type": "application/octet-stream", - }); - } else { - return assertUnreachable(encoding); - } -} - /** * Create a connection message handler */ diff --git a/packages/actor/src/app/inline-client-driver.ts b/packages/actor/src/app/inline-client-driver.ts new file mode 100644 index 000000000..51351651c --- /dev/null +++ b/packages/actor/src/app/inline-client-driver.ts @@ -0,0 +1,178 @@ +import * as errors from "@/actor/errors"; +import { logger } from "./log"; +import type { EventSource } from "eventsource"; +import type * as wsToServer from "@/actor/protocol/message/to-server"; +import { type Encoding, serialize } from "@/actor/protocol/serde"; +import { type ConnectionHandlers } from "@/actor/router-endpoints"; +import { HonoRequest, type Context as HonoContext, type Next } from "hono"; +import invariant from "invariant"; +import { ClientDriver } from "@/client/client"; +import { ManagerDriver } from "@/manager/driver"; +import { ActorQuery } from "@/manager/protocol/query"; + +/** + * Client driver that calls the manager driver inline. + * + * This driver can access private resources. + * + * This driver serves a double purpose as: + * - Providing the client for the internal requests + * - Provide the driver for the manager HTTP router (see manager/router.ts) + */ +export function createInlineClientDriver( + managerDriver: ManagerDriver, + connHandlers: ConnectionHandlers, +): ClientDriver { + //// Lazily import the dynamic imports so we don't have to turn `createClient` in to an aysnc fn + //const dynamicImports = (async () => { + // // Import dynamic dependencies + // const [WebSocket, EventSource] = await Promise.all([ + // importWebSocket(), + // importEventSource(), + // ]); + // return { + // WebSocket, + // EventSource, + // }; + //})(); + + const driver: ClientDriver = { + action: async = unknown[], Response = unknown>( + req: HonoRequest | undefined, + actorQuery: ActorQuery, + encoding: Encoding, + params: unknown, + actionName: string, + ...args: Args + ): Promise => { + // Get the actor ID and meta + const { actorId, meta } = await queryActor( + req, + actorQuery, + managerDriver, + ); + logger().debug("found actor for action", { actorId, meta }); + invariant(actorId, "Missing actor ID"); + + logger().debug("handling action", { actionName, encoding }); + + // Invoke the action + const { output } = await connHandlers.onAction({ + req, + params, + actionName, + actionArgs: args, + actorId, + }); + + return output as Response; + }, + + resolveActorId: async ( + req: HonoRequest | undefined, + actorQuery: ActorQuery, + _encodingKind: Encoding, + ): Promise => { + // Get the actor ID and meta + const { actorId } = await queryActor(req, actorQuery, managerDriver); + logger().debug("resolved actor", { actorId }); + invariant(actorId, "Missing actor ID"); + + return actorId; + }, + + connectWebSocket: async ( + req: HonoRequest | undefined, + actorQuery: ActorQuery, + encodingKind: Encoding, + ): Promise => { + throw "UNIMPLEMENTED"; + }, + + connectSse: async ( + req: HonoRequest | undefined, + actorQuery: ActorQuery, + encodingKind: Encoding, + params: unknown, + ): Promise => { + throw "UNIMPLEMENTED"; + }, + + sendHttpMessage: async ( + req: HonoRequest | undefined, + actorId: string, + encoding: Encoding, + connectionId: string, + connectionToken: string, + message: wsToServer.ToServer, + ): Promise => { + throw "UNIMPLEMENTED"; + }, + }; + + return driver; +} + +/** + * Query the manager driver to get or create an actor based on the provided query + */ +export async function queryActor( + req: HonoRequest | undefined, + query: ActorQuery, + driver: ManagerDriver, +): Promise<{ actorId: string; meta?: unknown }> { + logger().debug("querying actor", { query }); + let actorOutput: { actorId: string; meta?: unknown }; + if ("getForId" in query) { + const output = await driver.getForId({ + req, + actorId: query.getForId.actorId, + }); + if (!output) throw new errors.ActorNotFound(query.getForId.actorId); + actorOutput = output; + } else if ("getForKey" in query) { + const existingActor = await driver.getWithKey({ + req, + name: query.getForKey.name, + key: query.getForKey.key, + }); + if (!existingActor) { + throw new errors.ActorNotFound( + `${query.getForKey.name}:${JSON.stringify(query.getForKey.key)}`, + ); + } + actorOutput = existingActor; + } else if ("getOrCreateForKey" in query) { + const getOrCreateOutput = await driver.getOrCreateWithKey({ + req, + name: query.getOrCreateForKey.name, + key: query.getOrCreateForKey.key, + input: query.getOrCreateForKey.input, + region: query.getOrCreateForKey.region, + }); + actorOutput = { + actorId: getOrCreateOutput.actorId, + meta: getOrCreateOutput.meta, + }; + } else if ("create" in query) { + const createOutput = await driver.createActor({ + req, + name: query.create.name, + key: query.create.key, + input: query.create.input, + region: query.create.region, + }); + actorOutput = { + actorId: createOutput.actorId, + meta: createOutput.meta, + }; + } else { + throw new errors.InvalidRequest("Invalid query format"); + } + + logger().debug("actor query result", { + actorId: actorOutput.actorId, + meta: actorOutput.meta, + }); + return { actorId: actorOutput.actorId, meta: actorOutput.meta }; +} diff --git a/packages/actor/src/app/log.ts b/packages/actor/src/app/log.ts new file mode 100644 index 000000000..d56c5aa74 --- /dev/null +++ b/packages/actor/src/app/log.ts @@ -0,0 +1,7 @@ +import { getLogger } from "@/common//log"; + +export const LOGGER_NAME = "actor-app"; + +export function logger() { + return getLogger(LOGGER_NAME); +} diff --git a/packages/actor/src/client/actor-conn.ts b/packages/actor/src/client/actor-conn.ts index 9de47b400..35c7d43c8 100644 --- a/packages/actor/src/client/actor-conn.ts +++ b/packages/actor/src/client/actor-conn.ts @@ -1,13 +1,8 @@ import type { AnyActorDefinition } from "@/actor/definition"; -import type { Transport } from "@/actor/protocol/message/mod"; import type * as wsToClient from "@/actor/protocol/message/to-client"; import type * as wsToServer from "@/actor/protocol/message/to-server"; import type { Encoding } from "@/actor/protocol/serde"; -import { importEventSource } from "@/common/eventsource"; -import { MAX_CONN_PARAMS_SIZE } from "@/common/network"; -import { httpUserAgent } from "@/utils"; import { assertUnreachable, stringifyError } from "@/common/utils"; -import { importWebSocket } from "@/common/websocket"; import type { ActorQuery } from "@/manager/protocol/query"; import * as cbor from "cbor-x"; import pRetry from "p-retry"; @@ -20,14 +15,6 @@ import { import * as errors from "./errors"; import { logger } from "./log"; import { type WebSocketMessage as ConnMessage, messageLength, serializeWithEncoding } from "./utils"; -import { - HEADER_ACTOR_ID, - HEADER_ACTOR_QUERY, - HEADER_CONN_ID, - HEADER_CONN_TOKEN, - HEADER_ENCODING, - HEADER_CONN_PARAMS, -} from "@/actor/router-endpoints"; import type { EventSource } from "eventsource"; import { ActorDefinitionActions } from "./actor-common"; @@ -251,6 +238,7 @@ enc async #connectWebSocket() { const ws = await this.#driver.connectWebSocket( + undefined, this.#actorQuery, this.#encodingKind, ); @@ -281,6 +269,7 @@ enc async #connectSse() { const eventSource = await this.#driver.connectSse( + undefined, this.#actorQuery, this.#encodingKind, this.#params, @@ -649,6 +638,7 @@ enc throw new errors.InternalError("Missing connection ID or token."); const res = await this.#driver.sendHttpMessage( + undefined, this.#actorId, this.#encodingKind, this.#connectionId, diff --git a/packages/actor/src/client/actor-handle.ts b/packages/actor/src/client/actor-handle.ts index c1c458808..0da6fa69b 100644 --- a/packages/actor/src/client/actor-handle.ts +++ b/packages/actor/src/client/actor-handle.ts @@ -61,6 +61,7 @@ export class ActorHandleRaw { ...args: Args ): Promise { return await this.#driver.action( + undefined, this.#actorQuery, this.#encodingKind, this.#params, @@ -105,6 +106,7 @@ export class ActorHandleRaw { ) { // TODO: const actorId = await this.#driver.resolveActorId( + undefined, this.#actorQuery, this.#encodingKind, ); diff --git a/packages/actor/src/client/client.ts b/packages/actor/src/client/client.ts index 91145fdb6..a48e0d23f 100644 --- a/packages/actor/src/client/client.ts +++ b/packages/actor/src/client/client.ts @@ -15,6 +15,7 @@ import type { AnyActorDefinition } from "@/actor/definition"; import type * as wsToServer from "@/actor/protocol/message/to-server"; import type { EventSource } from "eventsource"; import { createHttpClientDriver } from "./http-client-driver"; +import { HonoRequest } from "hono"; /** Extract the actor registry from the app definition. */ export type ExtractActorsFromApp> = @@ -159,6 +160,7 @@ export const TRANSPORT_SYMBOL = Symbol("transport"); export interface ClientDriver { action = unknown[], Response = unknown>( + req: HonoRequest | undefined, actorQuery: ActorQuery, encoding: Encoding, params: unknown, @@ -166,19 +168,23 @@ export interface ClientDriver { ...args: Args ): Promise; resolveActorId( + req: HonoRequest | undefined, actorQuery: ActorQuery, - encodingKind: Encoding, + encoding: Encoding, ): Promise; connectWebSocket( + req: HonoRequest | undefined, actorQuery: ActorQuery, encodingKind: Encoding, ): Promise; connectSse( + req: HonoRequest | undefined, actorQuery: ActorQuery, encodingKind: Encoding, params: unknown, ): Promise; sendHttpMessage( + req: HonoRequest | undefined, actorId: string, encoding: Encoding, connectionId: string, @@ -353,6 +359,7 @@ export class ClientRaw { // Create the actor const actorId = await this.#driver.resolveActorId( + undefined, createQuery, this.#encodingKind, ); diff --git a/packages/actor/src/client/http-client-driver.ts b/packages/actor/src/client/http-client-driver.ts index 94cdb0c36..cb3af8a77 100644 --- a/packages/actor/src/client/http-client-driver.ts +++ b/packages/actor/src/client/http-client-driver.ts @@ -25,12 +25,15 @@ import { import type { ActionRequest } from "@/actor/protocol/http/action"; import type { ActionResponse } from "@/actor/protocol/message/to-client"; import { ClientDriver } from "./client"; +import { HonoRequest } from "hono"; /** * Client driver that communicates with the manager via HTTP. + * + * This driver cannot access private resources. */ export function createHttpClientDriver(managerEndpoint: string): ClientDriver { - // Lazily import the dynamic imports so we don't have to turn `createClient` in to an aysnc fn + // Lazily import the dynamic imports so we don't have to turn `createClient` in to an async fn const dynamicImports = (async () => { // Import dynamic dependencies const [WebSocket, EventSource] = await Promise.all([ @@ -45,6 +48,7 @@ export function createHttpClientDriver(managerEndpoint: string): ClientDriver { const driver: ClientDriver = { action: async = unknown[], Response = unknown>( + req: HonoRequest | undefined, actorQuery: ActorQuery, encoding: Encoding, params: unknown, @@ -77,6 +81,7 @@ export function createHttpClientDriver(managerEndpoint: string): ClientDriver { }, resolveActorId: async ( + req: HonoRequest | undefined, actorQuery: ActorQuery, encodingKind: Encoding, ): Promise => { @@ -112,6 +117,7 @@ export function createHttpClientDriver(managerEndpoint: string): ClientDriver { }, connectWebSocket: async ( + req: HonoRequest | undefined, actorQuery: ActorQuery, encodingKind: Encoding, ): Promise => { @@ -140,6 +146,7 @@ export function createHttpClientDriver(managerEndpoint: string): ClientDriver { }, connectSse: async ( + req: HonoRequest | undefined, actorQuery: ActorQuery, encodingKind: Encoding, params: unknown, @@ -170,6 +177,7 @@ export function createHttpClientDriver(managerEndpoint: string): ClientDriver { }, sendHttpMessage: async ( + req: HonoRequest | undefined, actorId: string, encoding: Encoding, connectionId: string, diff --git a/packages/actor/src/driver-test-suite/tests/actor-handle.ts b/packages/actor/src/driver-test-suite/tests/actor-handle.ts index 74ffdf6e6..26030d1ec 100644 --- a/packages/actor/src/driver-test-suite/tests/actor-handle.ts +++ b/packages/actor/src/driver-test-suite/tests/actor-handle.ts @@ -1,6 +1,6 @@ -import { describe, test, expect, vi } from "vitest"; +import { describe, test, expect } from "vitest"; import type { DriverTestConfig } from "../mod"; -import { setupDriverTest, waitFor } from "../utils"; +import { setupDriverTest } from "../utils"; import { COUNTER_APP_PATH, LIFECYCLE_APP_PATH, @@ -205,8 +205,10 @@ export function runActorHandleTests(driverTestConfig: DriverTestConfig) { ); // Create a normal handle to view events - const viewHandle = client.counter.getOrCreate(["test-lifecycle-action"]); - + const viewHandle = client.counter.getOrCreate([ + "test-lifecycle-action", + ]); + // Initial state should only have onStart const initialEvents = await viewHandle.getEvents(); expect(initialEvents).toContain("onStart"); @@ -217,35 +219,47 @@ export function runActorHandleTests(driverTestConfig: DriverTestConfig) { // Create a handle with trackLifecycle enabled for testing Action calls const trackingHandle = client.counter.getOrCreate( ["test-lifecycle-action"], - { params: { trackLifecycle: true } } + { params: { trackLifecycle: true } }, ); - + // Make an Action call await trackingHandle.increment(5); - + // Check that it triggered the lifecycle hooks const eventsAfterAction = await viewHandle.getEvents(); - + // Should have onBeforeConnect, onConnect, and onDisconnect for the Action call expect(eventsAfterAction).toContain("onBeforeConnect"); expect(eventsAfterAction).toContain("onConnect"); expect(eventsAfterAction).toContain("onDisconnect"); - + // Each should have count 1 - expect(eventsAfterAction.filter(e => e === "onBeforeConnect").length).toBe(1); - expect(eventsAfterAction.filter(e => e === "onConnect").length).toBe(1); - expect(eventsAfterAction.filter(e => e === "onDisconnect").length).toBe(1); - + expect( + eventsAfterAction.filter((e) => e === "onBeforeConnect").length, + ).toBe(1); + expect(eventsAfterAction.filter((e) => e === "onConnect").length).toBe( + 1, + ); + expect( + eventsAfterAction.filter((e) => e === "onDisconnect").length, + ).toBe(1); + // Make another Action call await trackingHandle.increment(10); - + // Check that it triggered another set of lifecycle hooks const eventsAfterSecondAction = await viewHandle.getEvents(); - + // Each hook should now have count 2 - expect(eventsAfterSecondAction.filter(e => e === "onBeforeConnect").length).toBe(2); - expect(eventsAfterSecondAction.filter(e => e === "onConnect").length).toBe(2); - expect(eventsAfterSecondAction.filter(e => e === "onDisconnect").length).toBe(2); + expect( + eventsAfterSecondAction.filter((e) => e === "onBeforeConnect").length, + ).toBe(2); + expect( + eventsAfterSecondAction.filter((e) => e === "onConnect").length, + ).toBe(2); + expect( + eventsAfterSecondAction.filter((e) => e === "onDisconnect").length, + ).toBe(2); }); test("should trigger lifecycle hooks for each Action call across multiple handles", async (c) => { @@ -256,31 +270,33 @@ export function runActorHandleTests(driverTestConfig: DriverTestConfig) { ); // Create a normal handle to view events - const viewHandle = client.counter.getOrCreate(["test-lifecycle-multi-handle"]); - + const viewHandle = client.counter.getOrCreate([ + "test-lifecycle-multi-handle", + ]); + // Create two tracking handles to the same actor const trackingHandle1 = client.counter.getOrCreate( ["test-lifecycle-multi-handle"], - { params: { trackLifecycle: true } } + { params: { trackLifecycle: true } }, ); - + const trackingHandle2 = client.counter.getOrCreate( ["test-lifecycle-multi-handle"], - { params: { trackLifecycle: true } } + { params: { trackLifecycle: true } }, ); - + // Make Action calls on both handles await trackingHandle1.increment(5); await trackingHandle2.increment(10); - + // Check lifecycle hooks const events = await viewHandle.getEvents(); - + // Should have 1 onStart, 2 each of onBeforeConnect, onConnect, and onDisconnect - expect(events.filter(e => e === "onStart").length).toBe(1); - expect(events.filter(e => e === "onBeforeConnect").length).toBe(2); - expect(events.filter(e => e === "onConnect").length).toBe(2); - expect(events.filter(e => e === "onDisconnect").length).toBe(2); + expect(events.filter((e) => e === "onStart").length).toBe(1); + expect(events.filter((e) => e === "onBeforeConnect").length).toBe(2); + expect(events.filter((e) => e === "onConnect").length).toBe(2); + expect(events.filter((e) => e === "onDisconnect").length).toBe(2); }); }); }); diff --git a/packages/actor/src/manager/driver.ts b/packages/actor/src/manager/driver.ts index 714893881..9bdb71c39 100644 --- a/packages/actor/src/manager/driver.ts +++ b/packages/actor/src/manager/driver.ts @@ -1,6 +1,6 @@ import type { ActorKey } from "@/common/utils"; import type { ManagerInspector } from "@/inspector/manager"; -import type { Env, Context as HonoContext } from "hono"; +import type { Env, Context as HonoContext, HonoRequest } from "hono"; export interface ManagerDriver { getForId(input: GetForIdInput): Promise; @@ -11,18 +11,18 @@ export interface ManagerDriver { inspector?: ManagerInspector; } export interface GetForIdInput { - c?: HonoContext; + req?: HonoRequest; actorId: string; } export interface GetWithKeyInput { - c?: HonoContext; + req?: HonoRequest; name: string; key: ActorKey; } export interface GetOrCreateWithKeyInput { - c?: HonoContext; + req?: HonoRequest; name: string; key: ActorKey; input?: unknown; @@ -30,7 +30,7 @@ export interface GetOrCreateWithKeyInput { } export interface CreateInput { - c?: HonoContext; + req?: HonoRequest; name: string; key: ActorKey; input?: unknown; diff --git a/packages/actor/src/manager/router.ts b/packages/actor/src/manager/router.ts index 18fa24a9a..a4a718ecb 100644 --- a/packages/actor/src/manager/router.ts +++ b/packages/actor/src/manager/router.ts @@ -1,14 +1,14 @@ import * as errors from "@/actor/errors"; -import type * as protoHttpResolve from "@/actor/protocol/http/resolve"; -import type { ToClient } from "@/actor/protocol/message/to-client"; -import { type Encoding, serialize } from "@/actor/protocol/serde"; +import * as protoHttpResolve from "@/actor/protocol/http/resolve"; +import * as protoHttpAction from "@/actor/protocol/http/action"; +import { + deserialize, + type Encoding, + InputData, + serialize, +} from "@/actor/protocol/serde"; import { - type ConnectionHandlers, getRequestEncoding, - handleConnectionMessage, - handleAction, - handleSseConnect, - handleWebSocketConnect, HEADER_ACTOR_ID, HEADER_CONN_ID, HEADER_CONN_PARAMS, @@ -17,6 +17,7 @@ import { HEADER_ACTOR_QUERY, ALL_HEADERS, getRequestQuery, + getRequestConnParams, } from "@/actor/router-endpoints"; import { assertUnreachable } from "@/actor/utils"; import type { AppConfig } from "@/app/config"; @@ -25,7 +26,7 @@ import { handleRouteNotFound, loggerMiddleware, } from "@/common/router"; -import { deconstructError } from "@/common/utils"; +import { stringifyError } from "@/common/utils"; import type { DriverConfig } from "@/driver-helpers/config"; import { type ManagerInspectorConnHandler, @@ -36,9 +37,6 @@ import { OpenAPIHono } from "@hono/zod-openapi"; import { z } from "@hono/zod-openapi"; import { createRoute } from "@hono/zod-openapi"; import { cors } from "hono/cors"; -import { streamSSE } from "hono/streaming"; -import type { WSContext } from "hono/ws"; -import invariant from "invariant"; import type { ManagerDriver } from "./driver"; import { logger } from "./log"; import { @@ -49,30 +47,9 @@ import { } from "./protocol/query"; import type { ActorQuery } from "./protocol/query"; import { VERSION } from "@/utils"; +import { ClientDriver } from "@/app/inline-client-driver"; -type ProxyMode = - | { - inline: { - handlers: ConnectionHandlers; - }; - } - | { - custom: { - onProxyRequest: OnProxyRequest; - onProxyWebSocket: OnProxyWebSocket; - }; - }; - -export type BuildProxyEndpoint = (c: HonoContext, actorId: string) => string; - -export type OnProxyRequest = ( - c: HonoContext, - actorRequest: Request, - actorId: string, - meta?: unknown, -) => Promise; - -export type OnProxyWebSocket = ( +export type ProxyWebSocketHandler = ( c: HonoContext, path: string, actorId: string, @@ -80,8 +57,19 @@ export type OnProxyWebSocket = ( ) => Promise; type ManagerRouterHandler = { + proxyWebSocket( + req: HonoRequest, + actorQuery: ActorQuery, + encodingKind: Encoding, + ): Promise; + + proxySse( + req: HonoRequest, + actorQuery: ActorQuery, + encodingKind: Encoding, + params: unknown, + ): Promise; onConnectInspector?: ManagerInspectorConnHandler; - proxyMode: ProxyMode; }; const OPENAPI_ENCODING = z.string().openapi({ @@ -133,6 +121,7 @@ function buildOpenApiResponses(schema: T) { export function createManagerRouter( appConfig: AppConfig, driverConfig: DriverConfig, + clientDriver: ClientDriver, handler: ManagerRouterHandler, ) { if (!driverConfig.drivers?.manager) { @@ -214,7 +203,7 @@ export function createManagerRouter( responses: buildOpenApiResponses(ResolveResponseSchema), }); - app.openapi(resolveRoute, (c) => handleResolveRequest(c, driver)); + app.openapi(resolveRoute, (c) => handleResolveRequest(c, clientDriver)); } // GET /actors/connect/websocket @@ -327,7 +316,14 @@ export function createManagerRouter( }); app.openapi(actionRoute, (c) => - handleActionRequest(c, appConfig, driverConfig, driver, handler), + handleActionRequest( + c, + appConfig, + driverConfig, + driver, + clientDriver, + handler, + ), ); } @@ -396,70 +392,6 @@ export function createManagerRouter( return app as unknown as Hono; } -/** - * Query the manager driver to get or create an actor based on the provided query - */ -export async function queryActor( - c: HonoContext, - query: ActorQuery, - driver: ManagerDriver, -): Promise<{ actorId: string; meta?: unknown }> { - logger().debug("querying actor", { query }); - let actorOutput: { actorId: string; meta?: unknown }; - if ("getForId" in query) { - const output = await driver.getForId({ - c, - actorId: query.getForId.actorId, - }); - if (!output) throw new errors.ActorNotFound(query.getForId.actorId); - actorOutput = output; - } else if ("getForKey" in query) { - const existingActor = await driver.getWithKey({ - c, - name: query.getForKey.name, - key: query.getForKey.key, - }); - if (!existingActor) { - throw new errors.ActorNotFound( - `${query.getForKey.name}:${JSON.stringify(query.getForKey.key)}`, - ); - } - actorOutput = existingActor; - } else if ("getOrCreateForKey" in query) { - const getOrCreateOutput = await driver.getOrCreateWithKey({ - c, - name: query.getOrCreateForKey.name, - key: query.getOrCreateForKey.key, - input: query.getOrCreateForKey.input, - region: query.getOrCreateForKey.region, - }); - actorOutput = { - actorId: getOrCreateOutput.actorId, - meta: getOrCreateOutput.meta, - }; - } else if ("create" in query) { - const createOutput = await driver.createActor({ - c, - name: query.create.name, - key: query.create.key, - input: query.create.input, - region: query.create.region, - }); - actorOutput = { - actorId: createOutput.actorId, - meta: createOutput.meta, - }; - } else { - throw new errors.InvalidRequest("Invalid query format"); - } - - logger().debug("actor query result", { - actorId: actorOutput.actorId, - meta: actorOutput.meta, - }); - return { actorId: actorOutput.actorId, meta: actorOutput.meta }; -} - /** * Handle SSE connection request */ @@ -470,110 +402,111 @@ async function handleSseConnectRequest( driver: ManagerDriver, handler: ManagerRouterHandler, ): Promise { - let encoding: Encoding | undefined; - try { - encoding = getRequestEncoding(c.req, false); - logger().debug("sse connection request received", { encoding }); - - const params = ConnectRequestSchema.safeParse({ - query: getRequestQuery(c, false), - encoding: c.req.header(HEADER_ENCODING), - params: c.req.header(HEADER_CONN_PARAMS), - }); - - if (!params.success) { - logger().error("invalid connection parameters", { - error: params.error, - }); - throw new errors.InvalidRequest(params.error); - } - - const query = params.data.query; - - // Get the actor ID and meta - const { actorId, meta } = await queryActor(c, query, driver); - invariant(actorId, "Missing actor ID"); - logger().debug("sse connection to actor", { actorId, meta }); - - // Handle based on mode - if ("inline" in handler.proxyMode) { - logger().debug("using inline proxy mode for sse connection"); - // Use the shared SSE handler - return await handleSseConnect( - c, - appConfig, - driverConfig, - handler.proxyMode.inline.handlers.onConnectSse, - actorId, - ); - } else if ("custom" in handler.proxyMode) { - logger().debug("using custom proxy mode for sse connection"); - const url = new URL("http://actor/connect/sse"); - const proxyRequest = new Request(url, c.req.raw); - proxyRequest.headers.set(HEADER_ENCODING, params.data.encoding); - if (params.data.connParams) { - proxyRequest.headers.set(HEADER_CONN_PARAMS, params.data.connParams); - } - return await handler.proxyMode.custom.onProxyRequest( - c, - proxyRequest, - actorId, - meta, - ); - } else { - assertUnreachable(handler.proxyMode); - } - } catch (error) { - // If we receive an error during setup, we send the error and close the socket immediately - // - // We have to return the error over SSE since SSE clients cannot read vanilla HTTP responses - - const { code, message, metadata } = deconstructError(error, logger(), { - sseEvent: "setup", - }); - - return streamSSE(c, async (stream) => { - try { - if (encoding) { - // Serialize and send the connection error - const errorMsg: ToClient = { - b: { - e: { - c: code, - m: message, - md: metadata, - }, - }, - }; - - // Send the error message to the client - const serialized = serialize(errorMsg, encoding); - await stream.writeSSE({ - data: - typeof serialized === "string" - ? serialized - : Buffer.from(serialized).toString("base64"), - }); - } else { - // We don't know the encoding, send an error and close - await stream.writeSSE({ - data: code, - event: "error", - }); - } - } catch (serializeError) { - logger().error("failed to send error to sse client", { - error: serializeError, - }); - await stream.writeSSE({ - data: "internal error during error handling", - event: "error", - }); - } - - // Stream will exit completely once function exits - }); - } + throw "UNIMPLEMENTED"; + //let encoding: Encoding | undefined; + //try { + // encoding = getRequestEncoding(c.req, false); + // logger().debug("sse connection request received", { encoding }); + // + // const params = ConnectRequestSchema.safeParse({ + // query: getRequestQuery(c, false), + // encoding: c.req.header(HEADER_ENCODING), + // params: c.req.header(HEADER_CONN_PARAMS), + // }); + // + // if (!params.success) { + // logger().error("invalid connection parameters", { + // error: params.error, + // }); + // throw new errors.InvalidRequest(params.error); + // } + // + // const query = params.data.query; + // + // // Get the actor ID and meta + // const { actorId, meta } = await queryActor(c, query, driver); + // invariant(actorId, "Missing actor ID"); + // logger().debug("sse connection to actor", { actorId, meta }); + // + // // Handle based on mode + // if ("inline" in handler.proxyMode) { + // logger().debug("using inline proxy mode for sse connection"); + // // Use the shared SSE handler + // return await handleSseConnect( + // c, + // appConfig, + // driverConfig, + // handler.proxyMode.inline.handlers.onConnectSse, + // actorId, + // ); + // } else if ("custom" in handler.proxyMode) { + // logger().debug("using custom proxy mode for sse connection"); + // const url = new URL("http://actor/connect/sse"); + // const proxyRequest = new Request(url, c.req.raw); + // proxyRequest.headers.set(HEADER_ENCODING, params.data.encoding); + // if (params.data.connParams) { + // proxyRequest.headers.set(HEADER_CONN_PARAMS, params.data.connParams); + // } + // return await handler.proxyMode.custom.onProxyRequest( + // c, + // proxyRequest, + // actorId, + // meta, + // ); + // } else { + // assertUnreachable(handler.proxyMode); + // } + //} catch (error) { + // // If we receive an error during setup, we send the error and close the socket immediately + // // + // // We have to return the error over SSE since SSE clients cannot read vanilla HTTP responses + // + // const { code, message, metadata } = deconstructError(error, logger(), { + // sseEvent: "setup", + // }); + // + // return streamSSE(c, async (stream) => { + // try { + // if (encoding) { + // // Serialize and send the connection error + // const errorMsg: ToClient = { + // b: { + // e: { + // c: code, + // m: message, + // md: metadata, + // }, + // }, + // }; + // + // // Send the error message to the client + // const serialized = serialize(errorMsg, encoding); + // await stream.writeSSE({ + // data: + // typeof serialized === "string" + // ? serialized + // : Buffer.from(serialized).toString("base64"), + // }); + // } else { + // // We don't know the encoding, send an error and close + // await stream.writeSSE({ + // data: code, + // event: "error", + // }); + // } + // } catch (serializeError) { + // logger().error("failed to send error to sse client", { + // error: serializeError, + // }); + // await stream.writeSSE({ + // data: "internal error during error handling", + // event: "error", + // }); + // } + // + // // Stream will exit completely once function exits + // }); + //} } /** @@ -591,103 +524,103 @@ async function handleWebSocketConnectRequest( driver: ManagerDriver, handler: ManagerRouterHandler, ): Promise { - invariant(upgradeWebSocket, "WebSockets not supported"); - - let encoding: Encoding | undefined; - try { - logger().debug("websocket connection request received"); - - // We can't use the standard headers with WebSockets - // - // All other information will be sent over the socket itself, since that data needs to be E2EE - const params = ConnectWebSocketRequestSchema.safeParse({ - query: getRequestQuery(c, true), - encoding: c.req.query("encoding"), - }); - if (!params.success) { - logger().error("invalid connection parameters", { - error: params.error, - }); - throw new errors.InvalidRequest(params.error); - } - - // Get the actor ID and meta - const { actorId, meta } = await queryActor(c, params.data.query, driver); - logger().debug("found actor for websocket connection", { actorId, meta }); - invariant(actorId, "missing actor id"); - - if ("inline" in handler.proxyMode) { - logger().debug("using inline proxy mode for websocket connection"); - invariant( - handler.proxyMode.inline.handlers.onConnectWebSocket, - "onConnectWebSocket not provided", - ); - - const onConnectWebSocket = - handler.proxyMode.inline.handlers.onConnectWebSocket; - return upgradeWebSocket((c) => { - return handleWebSocketConnect( - c, - appConfig, - driverConfig, - onConnectWebSocket, - actorId, - )(); - })(c, noopNext()); - } else if ("custom" in handler.proxyMode) { - logger().debug("using custom proxy mode for websocket connection"); - return await handler.proxyMode.custom.onProxyWebSocket( - c, - `/connect/websocket?encoding=${params.data.encoding}`, - actorId, - meta, - ); - } else { - assertUnreachable(handler.proxyMode); - } - } catch (error) { - // If we receive an error during setup, we send the error and close the socket immediately - // - // We have to return the error over WS since WebSocket clients cannot read vanilla HTTP responses - - const { code, message, metadata } = deconstructError(error, logger(), { - wsEvent: "setup", - }); - - return await upgradeWebSocket(() => ({ - onOpen: async (_evt: unknown, ws: WSContext) => { - if (encoding) { - try { - // Serialize and send the connection error - const errorMsg: ToClient = { - b: { - e: { - c: code, - m: message, - md: metadata, - }, - }, - }; - - // Send the error message to the client - const serialized = serialize(errorMsg, encoding); - ws.send(serialized); - - // Close the connection with an error code - ws.close(1011, code); - } catch (serializeError) { - logger().error("failed to send error to websocket client", { - error: serializeError, - }); - ws.close(1011, "internal error during error handling"); - } - } else { - // We don't know the encoding so we send what we can - ws.close(1011, code); - } - }, - }))(c, noopNext()); - } + //invariant(upgradeWebSocket, "WebSockets not supported"); + // + //let encoding: Encoding | undefined; + //try { + // logger().debug("websocket connection request received"); + // + // // We can't use the standard headers with WebSockets + // // + // // All other information will be sent over the socket itself, since that data needs to be E2EE + // const params = ConnectWebSocketRequestSchema.safeParse({ + // query: getRequestQuery(c, true), + // encoding: c.req.query("encoding"), + // }); + // if (!params.success) { + // logger().error("invalid connection parameters", { + // error: params.error, + // }); + // throw new errors.InvalidRequest(params.error); + // } + // + // // Get the actor ID and meta + // const { actorId, meta } = await queryActor(c, params.data.query, driver); + // logger().debug("found actor for websocket connection", { actorId, meta }); + // invariant(actorId, "missing actor id"); + // + // if ("inline" in handler.proxyMode) { + // logger().debug("using inline proxy mode for websocket connection"); + // invariant( + // handler.proxyMode.inline.handlers.onConnectWebSocket, + // "onConnectWebSocket not provided", + // ); + // + // const onConnectWebSocket = + // handler.proxyMode.inline.handlers.onConnectWebSocket; + // return upgradeWebSocket((c) => { + // return handleWebSocketConnect( + // c, + // appConfig, + // driverConfig, + // onConnectWebSocket, + // actorId, + // )(); + // })(c, noopNext()); + // } else if ("custom" in handler.proxyMode) { + // logger().debug("using custom proxy mode for websocket connection"); + // return await handler.proxyMode.custom.onProxyWebSocket( + // c, + // `/connect/websocket?encoding=${params.data.encoding}`, + // actorId, + // meta, + // ); + // } else { + // assertUnreachable(handler.proxyMode); + // } + //} catch (error) { + // // If we receive an error during setup, we send the error and close the socket immediately + // // + // // We have to return the error over WS since WebSocket clients cannot read vanilla HTTP responses + // + // const { code, message, metadata } = deconstructError(error, logger(), { + // wsEvent: "setup", + // }); + // + // return await upgradeWebSocket(() => ({ + // onOpen: async (_evt: unknown, ws: WSContext) => { + // if (encoding) { + // try { + // // Serialize and send the connection error + // const errorMsg: ToClient = { + // b: { + // e: { + // c: code, + // m: message, + // md: metadata, + // }, + // }, + // }; + // + // // Send the error message to the client + // const serialized = serialize(errorMsg, encoding); + // ws.send(serialized); + // + // // Close the connection with an error code + // ws.close(1011, code); + // } catch (serializeError) { + // logger().error("failed to send error to websocket client", { + // error: serializeError, + // }); + // ws.close(1011, "internal error during error handling"); + // } + // } else { + // // We don't know the encoding so we send what we can + // ws.close(1011, code); + // } + // }, + // }))(c, noopNext()); + //} } /** @@ -698,61 +631,62 @@ async function handleMessageRequest( appConfig: AppConfig, handler: ManagerRouterHandler, ): Promise { - logger().debug("connection message request received"); - try { - const params = ConnMessageRequestSchema.safeParse({ - actorId: c.req.header(HEADER_ACTOR_ID), - connId: c.req.header(HEADER_CONN_ID), - encoding: c.req.header(HEADER_ENCODING), - connToken: c.req.header(HEADER_CONN_TOKEN), - }); - if (!params.success) { - logger().error("invalid connection parameters", { - error: params.error, - }); - throw new errors.InvalidRequest(params.error); - } - const { actorId, connId, encoding, connToken } = params.data; - - // Handle based on mode - if ("inline" in handler.proxyMode) { - logger().debug("using inline proxy mode for connection message"); - // Use shared connection message handler with direct parameters - return handleConnectionMessage( - c, - appConfig, - handler.proxyMode.inline.handlers.onConnMessage, - connId, - connToken as string, - actorId, - ); - } else if ("custom" in handler.proxyMode) { - logger().debug("using custom proxy mode for connection message"); - const url = new URL(`http://actor/connections/message`); - - const proxyRequest = new Request(url, c.req.raw); - proxyRequest.headers.set(HEADER_ENCODING, encoding); - proxyRequest.headers.set(HEADER_CONN_ID, connId); - proxyRequest.headers.set(HEADER_CONN_TOKEN, connToken); - - return await handler.proxyMode.custom.onProxyRequest( - c, - proxyRequest, - actorId, - ); - } else { - assertUnreachable(handler.proxyMode); - } - } catch (error) { - logger().error("error proxying connection message", { error }); - - // Use ProxyError if it's not already an ActorError - if (!errors.ActorError.isActorError(error)) { - throw new errors.ProxyError("connection message", error); - } else { - throw error; - } - } + throw "UNIMPLEMENTED"; + //logger().debug("connection message request received"); + //try { + // const params = ConnMessageRequestSchema.safeParse({ + // actorId: c.req.header(HEADER_ACTOR_ID), + // connId: c.req.header(HEADER_CONN_ID), + // encoding: c.req.header(HEADER_ENCODING), + // connToken: c.req.header(HEADER_CONN_TOKEN), + // }); + // if (!params.success) { + // logger().error("invalid connection parameters", { + // error: params.error, + // }); + // throw new errors.InvalidRequest(params.error); + // } + // const { actorId, connId, encoding, connToken } = params.data; + // + // // Handle based on mode + // if ("inline" in handler.proxyMode) { + // logger().debug("using inline proxy mode for connection message"); + // // Use shared connection message handler with direct parameters + // return handleConnectionMessage( + // c, + // appConfig, + // handler.proxyMode.inline.handlers.onConnMessage, + // connId, + // connToken as string, + // actorId, + // ); + // } else if ("custom" in handler.proxyMode) { + // logger().debug("using custom proxy mode for connection message"); + // const url = new URL(`http://actor/connections/message`); + // + // const proxyRequest = new Request(url, c.req.raw); + // proxyRequest.headers.set(HEADER_ENCODING, encoding); + // proxyRequest.headers.set(HEADER_CONN_ID, connId); + // proxyRequest.headers.set(HEADER_CONN_TOKEN, connToken); + // + // return await handler.proxyMode.custom.onProxyRequest( + // c, + // proxyRequest, + // actorId, + // ); + // } else { + // assertUnreachable(handler.proxyMode); + // } + //} catch (error) { + // logger().error("error proxying connection message", { error }); + // + // // Use ProxyError if it's not already an ActorError + // if (!errors.ActorError.isActorError(error)) { + // throw new errors.ProxyError("connection message", error); + // } else { + // throw error; + // } + //} } /** @@ -763,6 +697,7 @@ async function handleActionRequest( appConfig: AppConfig, driverConfig: DriverConfig, driver: ManagerDriver, + clientDriver: ClientDriver, handler: ManagerRouterHandler, ): Promise { try { @@ -782,37 +717,76 @@ async function handleActionRequest( throw new errors.InvalidRequest(params.error); } - // Get the actor ID and meta - const { actorId, meta } = await queryActor(c, params.data.query, driver); - logger().debug("found actor for action", { actorId, meta }); - invariant(actorId, "Missing actor ID"); + const encoding = getRequestEncoding(c.req, false); + const parameters = getRequestConnParams(c.req, appConfig, driverConfig); - // Handle based on mode - if ("inline" in handler.proxyMode) { - logger().debug("using inline proxy mode for action call"); - // Use shared action handler with direct parameter - return handleAction( - c, - appConfig, - driverConfig, - handler.proxyMode.inline.handlers.onAction, - actionName, - actorId, - ); - } else if ("custom" in handler.proxyMode) { - logger().debug("using custom proxy mode for action call"); - const url = new URL( - `http://actor/action/${encodeURIComponent(actionName)}`, - ); - const proxyRequest = new Request(url, c.req.raw); - return await handler.proxyMode.custom.onProxyRequest( - c, - proxyRequest, - actorId, - meta, - ); + // Validate incoming request + let actionArgs: unknown[]; + if (encoding === "json") { + try { + actionArgs = await c.req.json(); + } catch (err) { + throw new errors.InvalidActionRequest("Invalid JSON"); + } + + if (!Array.isArray(actionArgs)) { + throw new errors.InvalidActionRequest( + "Action arguments must be an array", + ); + } + } else if (encoding === "cbor") { + try { + const value = await c.req.arrayBuffer(); + const uint8Array = new Uint8Array(value); + const deserialized = await deserialize( + uint8Array as unknown as InputData, + encoding, + ); + + // Validate using the action schema + const result = + protoHttpAction.ActionRequestSchema.safeParse(deserialized); + if (!result.success) { + throw new errors.InvalidActionRequest( + "Invalid action request format", + ); + } + + actionArgs = result.data.a; + } catch (err) { + throw new errors.InvalidActionRequest( + `Invalid binary format: ${stringifyError(err)}`, + ); + } + } else { + return assertUnreachable(encoding); + } + + // Call action + const output = await clientDriver.action( + c.req, + params.data.query, + encoding, + parameters, + actionName, + ...actionArgs, + ); + + // Encode the response + if (encoding === "json") { + return c.json(output as Record); + } else if (encoding === "cbor") { + // Use serialize from serde.ts instead of custom encoder + const responseData = { + o: output, // Use the format expected by ResponseOkSchema + }; + const serialized = serialize(responseData, encoding); + + return c.body(serialized as Uint8Array, 200, { + "Content-Type": "application/octet-stream", + }); } else { - assertUnreachable(handler.proxyMode); + return assertUnreachable(encoding); } } catch (error) { logger().error("error in action handler", { error }); @@ -831,7 +805,7 @@ async function handleActionRequest( */ async function handleResolveRequest( c: HonoContext, - driver: ManagerDriver, + clientDriver: ClientDriver, ): Promise { const encoding = getRequestEncoding(c.req, false); logger().debug("resolve request encoding", { encoding }); @@ -846,10 +820,11 @@ async function handleResolveRequest( throw new errors.InvalidRequest(params.error); } - // Get the actor ID and meta - const { actorId, meta } = await queryActor(c, params.data.query, driver); - logger().debug("resolved actor", { actorId, meta }); - invariant(actorId, "Missing actor ID"); + const actorId = await clientDriver.resolveActorId( + c.req, + params.data.query, + encoding, + ); // Format response according to protocol const response: protoHttpResolve.ResolveResponse = { diff --git a/packages/actor/src/topologies/coordinate/router/sse.ts b/packages/actor/src/topologies/coordinate/router/sse.ts index 456c3e220..712a410a7 100644 --- a/packages/actor/src/topologies/coordinate/router/sse.ts +++ b/packages/actor/src/topologies/coordinate/router/sse.ts @@ -4,9 +4,9 @@ import { encodeDataToString, serialize } from "@/actor/protocol/serde"; import type { CoordinateDriver } from "../driver"; import { RelayConn } from "../conn/mod"; import type { ActorDriver } from "@/actor/driver"; -import type { ConnectSseOpts, ConnectSseOutput } from "@/actor/router"; import { DriverConfig } from "@/driver-helpers/config"; import { AppConfig } from "@/app/config"; +import { ConnectSseOpts, ConnectSseOutput } from "@/actor/router-endpoints"; export async function serveSse( appConfig: AppConfig, @@ -15,7 +15,7 @@ export async function serveSse( CoordinateDriver: CoordinateDriver, globalState: GlobalState, actorId: string, - { encoding, params: params }: ConnectSseOpts, + { encoding, params }: ConnectSseOpts, ): Promise { let conn: RelayConn | undefined; return { diff --git a/packages/actor/src/topologies/coordinate/router/websocket.ts b/packages/actor/src/topologies/coordinate/router/websocket.ts index a5c7fb15d..ba8208b54 100644 --- a/packages/actor/src/topologies/coordinate/router/websocket.ts +++ b/packages/actor/src/topologies/coordinate/router/websocket.ts @@ -8,9 +8,9 @@ import type { CoordinateDriver } from "../driver"; import { RelayConn } from "../conn/mod"; import { publishMessageToLeader } from "../node/message"; import type { ActorDriver } from "@/actor/driver"; -import type { ConnectWebSocketOpts, ConnectWebSocketOutput } from "@/actor/router"; -import { DriverConfig } from "@/driver-helpers/config"; -import { AppConfig } from "@/app/config"; +import type { DriverConfig } from "@/driver-helpers/config"; +import type { AppConfig } from "@/app/config"; +import { ConnectWebSocketOpts, ConnectWebSocketOutput } from "@/actor/router-endpoints"; export async function serveWebSocket( appConfig: AppConfig, diff --git a/packages/actor/src/topologies/coordinate/topology.ts b/packages/actor/src/topologies/coordinate/topology.ts index 09b2b6a9b..abd53ad08 100644 --- a/packages/actor/src/topologies/coordinate/topology.ts +++ b/packages/actor/src/topologies/coordinate/topology.ts @@ -1,5 +1,3 @@ -import { serveSse } from "./router/sse"; -import { serveWebSocket } from "./router/websocket"; import { Node } from "./node/mod"; import type { ActorPeer } from "./actor-peer"; import * as errors from "@/actor/errors"; @@ -7,7 +5,6 @@ import * as events from "node:events"; import { publishMessageToLeader } from "./node/message"; import type { RelayConn } from "./conn/mod"; import { Hono } from "hono"; -import { createActorRouter } from "@/actor/router"; import { handleRouteError, handleRouteNotFound } from "@/common/router"; import type { DriverConfig } from "@/driver-helpers/config"; import type { AppConfig } from "@/app/config"; @@ -22,6 +19,10 @@ import type { ActionOutput, ConnectionHandlers, } from "@/actor/router-endpoints"; +import invariant from "invariant"; +import { createInlineClientDriver } from "@/app/inline-client-driver"; +import { serveWebSocket } from "./router/websocket"; +import { serveSse } from "./router/sse"; export interface GlobalState { nodeId: string; @@ -116,16 +117,22 @@ export class CoordinateTopology { }; // Build manager router - const managerRouter = createManagerRouter(appConfig, driverConfig, { - proxyMode: { - inline: { - handlers: connectionHandlers, + const managerDriver = driverConfig.drivers.manager; + invariant(managerDriver, "missing manager driver"); + const clientDriver = createInlineClientDriver( + managerDriver, + connectionHandlers, + ); + const managerRouter = createManagerRouter( + appConfig, + driverConfig, + clientDriver, + { + onConnectInspector: () => { + throw new errors.Unsupported("inspect"); }, }, - onConnectInspector: () => { - throw new errors.Unsupported("inspect"); - }, - }); + ); app.route("/", managerRouter); diff --git a/packages/actor/src/actor/router.ts b/packages/actor/src/topologies/partition/actor-router.ts similarity index 96% rename from packages/actor/src/actor/router.ts rename to packages/actor/src/topologies/partition/actor-router.ts index f599854ea..cbb06194f 100644 --- a/packages/actor/src/actor/router.ts +++ b/packages/actor/src/topologies/partition/actor-router.ts @@ -52,9 +52,7 @@ export interface ActorRouterHandler { } /** - * Creates a router that handles requests for the protocol and passes it off to the handler. - * - * This allows for creating a universal protocol across all platforms. + * Creates a router that runs on the partitioned instance. */ export function createActorRouter( appConfig: AppConfig, diff --git a/packages/actor/src/topologies/partition/toplogy.ts b/packages/actor/src/topologies/partition/toplogy.ts index 78cca6622..f7267995a 100644 --- a/packages/actor/src/topologies/partition/toplogy.ts +++ b/packages/actor/src/topologies/partition/toplogy.ts @@ -1,5 +1,4 @@ import { Hono } from "hono"; -import { createActorRouter } from "@/actor/router"; import type { AnyActorInstance } from "@/actor/instance"; import * as errors from "@/actor/errors"; import { @@ -24,11 +23,7 @@ import type { ActorKey } from "@/common/utils"; import type { DriverConfig } from "@/driver-helpers/config"; import type { AppConfig } from "@/app/config"; import type { ActorInspectorConnection } from "@/inspector/actor"; -import { - createManagerRouter, - OnProxyWebSocket, - type OnProxyRequest, -} from "@/manager/router"; +import { createManagerRouter } from "@/manager/router"; import type { ManagerInspectorConnection } from "@/inspector/manager"; import type { ConnectWebSocketOpts, @@ -39,6 +34,22 @@ import type { ConnectSseOutput, ActionOutput, } from "@/actor/router-endpoints"; +import { SendRequestHandler, ProxyWebSocketHandler } from "@/app/inline-client-driver"; +import { ClientDriver } from "@/client/client"; + +export type SendRequestHandler = ( + c: HonoContext, + actorRequest: Request, + actorId: string, + meta?: unknown, +) => Promise; + +export type OpenWebSocketHandler = ( + c: HonoContext, + path: string, + actorId: string, + meta?: unknown, +) => Promise; export class PartitionTopologyManager { router: Hono; @@ -47,14 +58,28 @@ export class PartitionTopologyManager { appConfig: AppConfig, driverConfig: DriverConfig, proxyCustomConfig: { - onProxyRequest: OnProxyRequest; - onProxyWebSocket: OnProxyWebSocket; + sendRequest: OnSendRequest; + openWebSocket: OnOpenWebSocket; + proxyRequest: SendRequestHandler; + proxyWebSocket: ProxyWebSocketHandler; }, ) { - this.router = createManagerRouter(appConfig, driverConfig, { - proxyMode: { - custom: proxyCustomConfig, - }, + function unimplemented(): never { + throw new Error("UNIMPLEMENTED"); + } + + // TODO: needs a custom client driver that will forward to the actor + const clientDriver: ClientDriver = { + action: unimplemented, + resolveActorId: unimplemented, + connectWebSocket: unimplemented, + connectSse: unimplemented, + sendHttpMessage: unimplemented, + }; + + this.router = createManagerRouter(appConfig, driverConfig, clientDriver, { + proxyRequest, + proxyWebSocket, onConnectInspector: async () => { const inspector = driverConfig.drivers?.manager?.inspector; if (!inspector) throw new errors.Unsupported("inspector"); @@ -85,8 +110,6 @@ export class PartitionTopologyManager { /** Manages the actor in the topology. */ export class PartitionTopologyActor { - router: Hono; - #appConfig: AppConfig; #driverConfig: DriverConfig; #connDrivers: Record; @@ -110,189 +133,189 @@ export class PartitionTopologyActor { this.#connDrivers = createGenericConnDrivers(genericConnGlobalState); // TODO: Store this actor router globally so we're not re-initializing it for every DO - this.router = createActorRouter(appConfig, driverConfig, { - getActorId: async () => { - if (this.#actorStartedPromise) await this.#actorStartedPromise.promise; - return this.actor.id; - }, - connectionHandlers: { - onConnectWebSocket: async ( - opts: ConnectWebSocketOpts, - ): Promise => { - if (this.#actorStartedPromise) - await this.#actorStartedPromise.promise; - - const actor = this.#actor; - if (!actor) throw new Error("Actor should be defined"); - - const connId = generateConnId(); - const connToken = generateConnToken(); - const connState = await actor.prepareConn(opts.params, opts.req.raw); - - let conn: AnyConn | undefined; - return { - onOpen: async (ws) => { - // Save socket - genericConnGlobalState.websockets.set(connId, ws); - - // Create connection - conn = await actor.createConn( - connId, - connToken, - opts.params, - connState, - CONN_DRIVER_GENERIC_WEBSOCKET, - { - encoding: opts.encoding, - } satisfies GenericWebSocketDriverState, - ); - }, - onMessage: async (message) => { - logger().debug("received message"); - - if (!conn) { - logger().warn("`conn` does not exist"); - return; - } - - await actor.processMessage(message, conn); - }, - onClose: async () => { - genericConnGlobalState.websockets.delete(connId); - - if (conn) { - actor.__removeConn(conn); - } - }, - }; - }, - onConnectSse: async ( - opts: ConnectSseOpts, - ): Promise => { - if (this.#actorStartedPromise) - await this.#actorStartedPromise.promise; - - const actor = this.#actor; - if (!actor) throw new Error("Actor should be defined"); - - const connId = generateConnId(); - const connToken = generateConnToken(); - const connState = await actor.prepareConn(opts.params, opts.req.raw); - - let conn: AnyConn | undefined; - return { - onOpen: async (stream) => { - // Save socket - genericConnGlobalState.sseStreams.set(connId, stream); - - // Create connection - conn = await actor.createConn( - connId, - connToken, - opts.params, - connState, - CONN_DRIVER_GENERIC_SSE, - { encoding: opts.encoding } satisfies GenericSseDriverState, - ); - }, - onClose: async () => { - genericConnGlobalState.sseStreams.delete(connId); - - if (conn) { - actor.__removeConn(conn); - } - }, - }; - }, - onAction: async (opts: ActionOpts): Promise => { - let conn: AnyConn | undefined; - try { - // Wait for init to finish - if (this.#actorStartedPromise) - await this.#actorStartedPromise.promise; - - const actor = this.#actor; - if (!actor) throw new Error("Actor should be defined"); - - // Create conn - const connState = await actor.prepareConn( - opts.params, - opts.req.raw, - ); - conn = await actor.createConn( - generateConnId(), - generateConnToken(), - opts.params, - connState, - CONN_DRIVER_GENERIC_HTTP, - {} satisfies GenericHttpDriverState, - ); - - // Call action - const ctx = new ActionContext(actor.actorContext!, conn!); - const output = await actor.executeAction( - ctx, - opts.actionName, - opts.actionArgs, - ); - - return { output }; - } finally { - if (conn) { - this.#actor?.__removeConn(conn); - } - } - }, - onConnMessage: async (opts: ConnsMessageOpts): Promise => { - // Wait for init to finish - if (this.#actorStartedPromise) - await this.#actorStartedPromise.promise; - - const actor = this.#actor; - if (!actor) throw new Error("Actor should be defined"); - - // Find connection - const conn = actor.conns.get(opts.connId); - if (!conn) { - throw new errors.ConnNotFound(opts.connId); - } - - // Authenticate connection - if (conn._token !== opts.connToken) { - throw new errors.IncorrectConnToken(); - } - - // Process message - await actor.processMessage(opts.message, conn); - }, - }, - onConnectInspector: async () => { - if (this.#actorStartedPromise) await this.#actorStartedPromise.promise; - - const actor = this.#actor; - if (!actor) throw new Error("Actor should be defined"); - - let conn: ActorInspectorConnection | undefined; - return { - onOpen: async (ws) => { - conn = actor.inspector.createConnection(ws); - }, - onMessage: async (message) => { - if (!conn) { - logger().warn("`conn` does not exist"); - return; - } - - actor.inspector.processMessage(conn, message); - }, - onClose: async () => { - if (conn) { - actor.inspector.removeConnection(conn); - } - }, - }; - }, - }); + //this.router = createActorRouter(appConfig, driverConfig, { + // getActorId: async () => { + // if (this.#actorStartedPromise) await this.#actorStartedPromise.promise; + // return this.actor.id; + // }, + // connectionHandlers: { + // onConnectWebSocket: async ( + // opts: ConnectWebSocketOpts, + // ): Promise => { + // if (this.#actorStartedPromise) + // await this.#actorStartedPromise.promise; + // + // const actor = this.#actor; + // if (!actor) throw new Error("Actor should be defined"); + // + // const connId = generateConnId(); + // const connToken = generateConnToken(); + // const connState = await actor.prepareConn(opts.params, opts.req.raw); + // + // let conn: AnyConn | undefined; + // return { + // onOpen: async (ws) => { + // // Save socket + // genericConnGlobalState.websockets.set(connId, ws); + // + // // Create connection + // conn = await actor.createConn( + // connId, + // connToken, + // opts.params, + // connState, + // CONN_DRIVER_GENERIC_WEBSOCKET, + // { + // encoding: opts.encoding, + // } satisfies GenericWebSocketDriverState, + // ); + // }, + // onMessage: async (message) => { + // logger().debug("received message"); + // + // if (!conn) { + // logger().warn("`conn` does not exist"); + // return; + // } + // + // await actor.processMessage(message, conn); + // }, + // onClose: async () => { + // genericConnGlobalState.websockets.delete(connId); + // + // if (conn) { + // actor.__removeConn(conn); + // } + // }, + // }; + // }, + // onConnectSse: async ( + // opts: ConnectSseOpts, + // ): Promise => { + // if (this.#actorStartedPromise) + // await this.#actorStartedPromise.promise; + // + // const actor = this.#actor; + // if (!actor) throw new Error("Actor should be defined"); + // + // const connId = generateConnId(); + // const connToken = generateConnToken(); + // const connState = await actor.prepareConn(opts.params, opts.req.raw); + // + // let conn: AnyConn | undefined; + // return { + // onOpen: async (stream) => { + // // Save socket + // genericConnGlobalState.sseStreams.set(connId, stream); + // + // // Create connection + // conn = await actor.createConn( + // connId, + // connToken, + // opts.params, + // connState, + // CONN_DRIVER_GENERIC_SSE, + // { encoding: opts.encoding } satisfies GenericSseDriverState, + // ); + // }, + // onClose: async () => { + // genericConnGlobalState.sseStreams.delete(connId); + // + // if (conn) { + // actor.__removeConn(conn); + // } + // }, + // }; + // }, + // onAction: async (opts: ActionOpts): Promise => { + // let conn: AnyConn | undefined; + // try { + // // Wait for init to finish + // if (this.#actorStartedPromise) + // await this.#actorStartedPromise.promise; + // + // const actor = this.#actor; + // if (!actor) throw new Error("Actor should be defined"); + // + // // Create conn + // const connState = await actor.prepareConn( + // opts.params, + // opts.req.raw, + // ); + // conn = await actor.createConn( + // generateConnId(), + // generateConnToken(), + // opts.params, + // connState, + // CONN_DRIVER_GENERIC_HTTP, + // {} satisfies GenericHttpDriverState, + // ); + // + // // Call action + // const ctx = new ActionContext(actor.actorContext!, conn!); + // const output = await actor.executeAction( + // ctx, + // opts.actionName, + // opts.actionArgs, + // ); + // + // return { output }; + // } finally { + // if (conn) { + // this.#actor?.__removeConn(conn); + // } + // } + // }, + // onConnMessage: async (opts: ConnsMessageOpts): Promise => { + // // Wait for init to finish + // if (this.#actorStartedPromise) + // await this.#actorStartedPromise.promise; + // + // const actor = this.#actor; + // if (!actor) throw new Error("Actor should be defined"); + // + // // Find connection + // const conn = actor.conns.get(opts.connId); + // if (!conn) { + // throw new errors.ConnNotFound(opts.connId); + // } + // + // // Authenticate connection + // if (conn._token !== opts.connToken) { + // throw new errors.IncorrectConnToken(); + // } + // + // // Process message + // await actor.processMessage(opts.message, conn); + // }, + // }, + // onConnectInspector: async () => { + // if (this.#actorStartedPromise) await this.#actorStartedPromise.promise; + // + // const actor = this.#actor; + // if (!actor) throw new Error("Actor should be defined"); + // + // let conn: ActorInspectorConnection | undefined; + // return { + // onOpen: async (ws) => { + // conn = actor.inspector.createConnection(ws); + // }, + // onMessage: async (message) => { + // if (!conn) { + // logger().warn("`conn` does not exist"); + // return; + // } + // + // actor.inspector.processMessage(conn, message); + // }, + // onClose: async () => { + // if (conn) { + // actor.inspector.removeConnection(conn); + // } + // }, + // }; + // }, + //}); } async start(id: string, name: string, key: ActorKey, region: string) { diff --git a/packages/actor/src/topologies/standalone/topology.ts b/packages/actor/src/topologies/standalone/topology.ts index dc6e824df..0e09799b9 100644 --- a/packages/actor/src/topologies/standalone/topology.ts +++ b/packages/actor/src/topologies/standalone/topology.ts @@ -32,6 +32,8 @@ import type { ActionOutput, ConnectionHandlers, } from "@/actor/router-endpoints"; +import { createInlineClientDriver } from "@/app/inline-client-driver"; +import invariant from "invariant"; class ActorHandler { /** Will be undefined if not yet loaded. */ @@ -214,7 +216,9 @@ export class StandaloneTopology { const { actor } = await this.#getActor(opts.actorId); // Create conn - const connState = await actor.prepareConn(opts.params, opts.req.raw); + const req = opts.req; + invariant(req, "missing request") + const connState = await actor.prepareConn(opts.params, req.raw); conn = await actor.createConn( generateConnId(), generateConnToken(), @@ -260,37 +264,43 @@ export class StandaloneTopology { }; // Build manager router - const managerRouter = createManagerRouter(appConfig, driverConfig, { - proxyMode: { - inline: { - handlers: sharedConnectionHandlers, + const managerDriver = this.#driverConfig.drivers.manager; + invariant(managerDriver, "missing manager driver"); + const clientDriver = createInlineClientDriver( + managerDriver, + sharedConnectionHandlers, + ); + const managerRouter = createManagerRouter( + appConfig, + driverConfig, + clientDriver, + { + onConnectInspector: async () => { + const inspector = driverConfig.drivers?.manager?.inspector; + if (!inspector) throw new errors.Unsupported("inspector"); + + let conn: ManagerInspectorConnection | undefined; + return { + onOpen: async (ws) => { + conn = inspector.createConnection(ws); + }, + onMessage: async (message) => { + if (!conn) { + logger().warn("`conn` does not exist"); + return; + } + + inspector.processMessage(conn, message); + }, + onClose: async () => { + if (conn) { + inspector.removeConnection(conn); + } + }, + }; }, }, - onConnectInspector: async () => { - const inspector = driverConfig.drivers?.manager?.inspector; - if (!inspector) throw new errors.Unsupported("inspector"); - - let conn: ManagerInspectorConnection | undefined; - return { - onOpen: async (ws) => { - conn = inspector.createConnection(ws); - }, - onMessage: async (message) => { - if (!conn) { - logger().warn("`conn` does not exist"); - return; - } - - inspector.processMessage(conn, message); - }, - onClose: async () => { - if (conn) { - inspector.removeConnection(conn); - } - }, - }; - }, - }); + ); app.route("/", managerRouter); diff --git a/packages/actor/src/utils.ts b/packages/actor/src/utils.ts index 023b4d2cb..16f3f860b 100644 --- a/packages/actor/src/utils.ts +++ b/packages/actor/src/utils.ts @@ -14,7 +14,7 @@ export function httpUserAgent(): string { } // Library - let userAgent = `ActorCore/${VERSION}`; + let userAgent = `RivetKit/${VERSION}`; // Navigator const navigatorObj = typeof navigator !== "undefined" ? navigator : undefined; diff --git a/packages/platforms/rivet/src/manager-driver.ts b/packages/platforms/rivet/src/manager-driver.ts index d3d292bea..ef3bc79dd 100644 --- a/packages/platforms/rivet/src/manager-driver.ts +++ b/packages/platforms/rivet/src/manager-driver.ts @@ -49,8 +49,8 @@ export class RivetManagerDriver implements ManagerDriver { if (res.actor.tags.role !== "actor") { throw new Error(`Actor ${res.actor.id} does not have an actor role.`); } - if (res.actor.tags.framework !== "@rivetkit/actor") { - throw new Error(`Actor ${res.actor.id} is not an ActorCore actor.`); + if (res.actor.tags.framework !== "rivetkit") { + throw new Error(`Actor ${res.actor.id} is not RivetKit actor.`); } return { @@ -153,7 +153,7 @@ export class RivetManagerDriver implements ManagerDriver { build_tags: { name, role: "actor", - framework: "@rivetkit/actor", + framework: "rivetkit", current: "true", }, region, @@ -235,7 +235,7 @@ export class RivetManagerDriver implements ManagerDriver { name, key: serializeKeyForTag(key), role: "actor", - framework: "@rivetkit/actor", + framework: "rivetkit", }; }