Skip to content

Commit 6d4c516

Browse files
feat: Add ctx.agents.react for multi-step agent functionality (#854)
* feat: Add ctx.agents.react for multi-step agent functionality * Update sdk-node/src/workflows/workflow.ts Co-authored-by: John Smith <john@johnjcsmith.com> * fix: Correct variable name from instruction to instructions across modules and update package dependency version * chore: Update package-lock.json * feat: Enhance AI message handling with previous attempt tracking and parameter refinement --------- Co-authored-by: John Smith <john@johnjcsmith.com>
1 parent 8dad0f7 commit 6d4c516

File tree

11 files changed

+338
-89
lines changed

11 files changed

+338
-89
lines changed

app/client/contract.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1458,7 +1458,7 @@ export const definition = {
14581458
}),
14591459
body: z.object({
14601460
input: z.string(),
1461-
instruction: z.string().optional(),
1461+
instructions: z.string().optional(),
14621462
schema: z.record(z.any()),
14631463
}),
14641464
headers: z.object({

control-plane/package-lock.json

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

control-plane/package.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
"@e2b/code-interpreter": "^1.0.4",
2828
"@fastify/cors": "^8.5.0",
2929
"@hyperdx/node-opentelemetry": "^0.8.1",
30-
"@l1m/core": "^0.1.3",
30+
"@l1m/core": "^0.1.5",
3131
"@langchain/cohere": "^0.3.1",
3232
"@langchain/langgraph": "^0.1.9",
3333
"@nangohq/node": "^0.48.1",
@@ -97,4 +97,4 @@
9797
"ts-node": "^10.9.2",
9898
"tsx": "^4.7.0"
9999
}
100-
}
100+
}

control-plane/src/modules/contract.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1458,7 +1458,7 @@ export const definition = {
14581458
}),
14591459
body: z.object({
14601460
input: z.string(),
1461-
instruction: z.string().optional(),
1461+
instructions: z.string().optional(),
14621462
schema: z.record(z.any()),
14631463
}),
14641464
headers: z.object({

control-plane/src/modules/router.ts

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1566,9 +1566,9 @@ export const router = initServer().router(contract, {
15661566
status: 200,
15671567
body: tools,
15681568
};
1569-
},
1569+
},
15701570
l1mStructured: async request => {
1571-
const { input, instruction, schema } = request.body;
1571+
const { input, instructions, schema } = request.body;
15721572
const { clusterId } = request.params;
15731573

15741574
const auth = request.request.getAuth();
@@ -1579,7 +1579,7 @@ export const router = initServer().router(contract, {
15791579
const providerUrl = request.headers["x-provider-url"];
15801580

15811581
const executionId = request.headers["x-workflow-execution-id"];
1582-
const maxAttempts = request.headers["x-max-attempts"]
1582+
const maxAttempts = request.headers["x-max-attempts"];
15831583

15841584
if (!executionId) {
15851585
return {
@@ -1591,14 +1591,14 @@ export const router = initServer().router(contract, {
15911591
}
15921592

15931593
const hash = crypto.createHash("sha256");
1594-
hash.update(input)
1595-
hash.update(JSON.stringify(schema))
1596-
hash.update(providerModel)
1597-
hash.update(providerKey)
1598-
hash.update(executionId)
1599-
instruction && hash.update(instruction)
1594+
hash.update(input);
1595+
hash.update(JSON.stringify(schema));
1596+
hash.update(providerModel);
1597+
hash.update(providerKey);
1598+
hash.update(executionId);
1599+
instructions && hash.update(instructions);
16001600

1601-
const messageKey = `${executionId}_structured_${hash.digest("hex")}`
1601+
const messageKey = `${executionId}_structured_${hash.digest("hex")}`;
16021602

16031603
const existingMessage = await kv.get(clusterId, messageKey);
16041604
if (existingMessage) {
@@ -1636,7 +1636,7 @@ export const router = initServer().router(contract, {
16361636
url: providerUrl,
16371637
key: providerKey,
16381638
model: providerModel,
1639-
}
1639+
};
16401640

16411641
if (providerUrl.includes("inferable") || providerUrl === "") {
16421642
if (!["claude-3-5-sonnet", "claude-3-haiku"].includes(providerModel)) {
@@ -1652,7 +1652,7 @@ export const router = initServer().router(contract, {
16521652
identifier: providerModel as any,
16531653
trackingOptions: {
16541654
clusterId: clusterId,
1655-
}
1655+
},
16561656
});
16571657

16581658
provider = async (params, prompt, previousAttempts) => {
@@ -1664,7 +1664,7 @@ export const router = initServer().router(contract, {
16641664
messages.push({
16651665
role: "user",
16661666
content: [
1667-
{ type: "text", text: `${instruction} ${prompt}` },
1667+
{ type: "text", text: `${instructions} ${prompt}` },
16681668
{
16691669
type: "image",
16701670
source: {
@@ -1678,15 +1678,19 @@ export const router = initServer().router(contract, {
16781678
} else {
16791679
messages.push({
16801680
role: "user",
1681-
content: `${input} ${instruction} ${prompt}`,
1681+
content: `${input} ${instructions} ${prompt}`,
16821682
});
16831683
}
16841684

16851685
if (previousAttempts.length > 0) {
1686-
previousAttempts.forEach((attempt) => {
1686+
previousAttempts.forEach(attempt => {
16871687
messages.push({
16881688
role: "user",
1689-
content: "You previously responded: " + attempt.raw + " which produced validation errors: " + attempt.errors,
1689+
content:
1690+
"You previously responded: " +
1691+
attempt.raw +
1692+
" which produced validation errors: " +
1693+
attempt.errors,
16901694
});
16911695
});
16921696
}
@@ -1700,17 +1704,17 @@ export const router = initServer().router(contract, {
17001704
} else {
17011705
throw new Error("Anthropic API returned invalid response");
17021706
}
1703-
}
1707+
};
17041708
}
17051709

17061710
const result = await structured({
17071711
input,
17081712
type,
17091713
schema,
17101714
maxAttempts: maxAttempts ? parseInt(maxAttempts) : 3,
1711-
instruction,
1715+
instructions,
17121716
provider,
1713-
})
1717+
});
17141718

17151719
if (!result.valid || !result.structured) {
17161720
return {

sdk-node/package.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"description": "Javascript SDK for inferable.ai",
55
"main": "bin/index.js",
66
"scripts": {
7-
"build": "tsc",
7+
"build": "tsc -p tsconfig.build.json",
88
"clean": "rm -rf ./bin",
99
"prepare": "husky",
1010
"test": "jest ./src --runInBand --forceExit --setupFiles dotenv/config",
@@ -56,4 +56,4 @@
5656
"email": "hi@inferable.ai",
5757
"url": "https://github.com/inferablehq/inferable/issues"
5858
}
59-
}
59+
}

sdk-node/src/types.ts

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,20 @@ export type ToolSchema<T extends z.ZodTypeAny | JsonSchemaInput> = {
7878
input: T;
7979
};
8080

81-
export type ToolRegistrationInput<
82-
T extends z.ZodTypeAny | JsonSchemaInput,
83-
> = {
81+
export type ToolRegistrationInput<T extends z.ZodTypeAny | JsonSchemaInput> = {
8482
name: string;
8583
// eslint-disable-next-line @typescript-eslint/no-explicit-any
86-
func: (input: ToolInput<T>, context: JobContext) => any;
84+
func: (input: ToolInput<T>, context: JobContext) => Promise<any>;
8785
schema?: ToolSchema<T>;
8886
config?: ToolConfig;
8987
description?: string;
9088
};
89+
90+
export type WorkflowToolRegistrationInput<
91+
T extends z.ZodTypeAny | JsonSchemaInput,
92+
> = {
93+
name: string;
94+
inputSchema?: T;
95+
config?: ToolConfig;
96+
func: (input: ToolInput<T>, context: JobContext) => Promise<unknown>;
97+
};

sdk-node/src/workflows/workflow.test.ts

Lines changed: 49 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import { z } from "zod";
22
import { helpers } from "./workflow";
33
import { inferableInstance } from "../tests/utils";
4+
import assert from "assert";
45

56
describe("workflow", () => {
67
jest.setTimeout(60_000);
@@ -12,18 +13,6 @@ describe("workflow", () => {
1213
const onSimpleResult = jest.fn();
1314
const toolCall = jest.fn();
1415

15-
inferable.tools.register({
16-
func: (_i, _c) => {
17-
toolCall();
18-
return {
19-
word: "needle",
20-
};
21-
},
22-
name: "searchHaystack",
23-
});
24-
25-
inferable.tools.listen();
26-
2716
// Generate a unique workflow name to prevent conflicts with other tests
2817
const workflowName = `haystack-search-${Math.random().toString(36).substring(2, 15)}`;
2918

@@ -35,59 +24,88 @@ describe("workflow", () => {
3524
}),
3625
});
3726

27+
workflow.tools.register({
28+
name: "searchHaystack2",
29+
inputSchema: z.object({
30+
searchQuery: z.string(),
31+
}),
32+
func: async (input) => {
33+
toolCall(input);
34+
if (input.searchQuery === "marco") {
35+
return { word: "not-needle" };
36+
} else if (input.searchQuery === "marco 42") {
37+
return { word: "needle" };
38+
} else {
39+
return { word: `not-found-${input.searchQuery}` };
40+
}
41+
},
42+
});
43+
3844
workflow.version(1).define(async (ctx, input) => {
3945
onStart(input);
4046
ctx.log("info", { message: "Starting workflow" });
41-
const searchAgent = ctx.agent({
47+
const { word } = await ctx.agents.react({
4248
name: "search",
43-
tools: ["searchHaystack"],
44-
systemPrompt: helpers.structuredPrompt({
49+
instructions: helpers.structuredPrompt({
4550
facts: ["You are haystack searcher"],
46-
goals: ["Find the special word in the haystack"],
51+
goals: [
52+
"Find the special word in the haystack. Only search for the words asked explictly by the user.",
53+
],
4754
}),
48-
resultSchema: z.object({
55+
schema: z.object({
4956
word: z.string(),
5057
}),
58+
tools: ["searchHaystack2"],
59+
input: `Try the searchQuery 'marco'.`,
60+
onBeforeReturn: async (result, agent) => {
61+
if (result.word !== "needle") {
62+
await agent.sendMessage("Try the searchQuery 'marco 42'.");
63+
}
64+
},
5165
});
5266

53-
const result = await searchAgent.trigger({
54-
data: {},
55-
});
67+
assert(word === "needle", `Expected word to be "needle", got ${word}`);
5668

57-
ctx.result("testResultCall", async () => {
69+
const cachedResult = await ctx.result("testResultCall", async () => {
5870
return {
5971
word: "needle",
6072
};
6173
});
6274

63-
if (!result || !result.result || !result.result.word) {
64-
throw new Error("No result");
65-
}
75+
assert(
76+
cachedResult.word === "needle",
77+
`Expected cachedResult to be "needle", got ${cachedResult.word}`,
78+
);
6679

67-
onAgentResult(result.result.word);
80+
onAgentResult(cachedResult.word);
6881

6982
ctx.log("info", { message: "About to run simple LLM call" });
7083

7184
await ctx.llm.structured({
7285
input: "Return the word, needle.",
7386
schema: z.object({
7487
word: z.string(),
75-
})
88+
}),
7689
});
7790

7891
// Duplicate call
7992
const simpleResult = await ctx.llm.structured({
8093
input: "Return the word, needle.",
8194
schema: z.object({
8295
word: z.string(),
83-
})
96+
}),
8497
});
8598

86-
if (!simpleResult || !simpleResult.word) {
87-
throw new Error("No simpleResult");
88-
}
99+
assert(
100+
simpleResult.word === "needle",
101+
`Expected simpleResult to be "needle", got ${simpleResult.word}`,
102+
);
103+
89104
onSimpleResult(simpleResult.word);
90105

106+
return {
107+
word: "needle",
108+
};
91109
});
92110

93111
await workflow.listen();
@@ -114,7 +132,7 @@ describe("workflow", () => {
114132
expect(onAgentResult).toHaveBeenCalledWith("needle");
115133
expect(onAgentResult).toHaveBeenCalledTimes(1);
116134

117-
expect(toolCall).toHaveBeenCalledTimes(1);
135+
expect(toolCall).toHaveBeenCalledTimes(2);
118136

119137
expect(onSimpleResult).toHaveBeenCalledWith("needle");
120138
expect(onSimpleResult).toHaveBeenCalledTimes(1);

0 commit comments

Comments
 (0)