Skip to content

Commit ab6aec2

Browse files
committed
support tool calls
1 parent b0ddc2d commit ab6aec2

File tree

6 files changed

+156
-36
lines changed

6 files changed

+156
-36
lines changed

src/__tests__/Client.test.ts

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,15 @@ describe('client', () => {
7878
promptMessages: [
7979
{
8080
content: 'Hello, world!',
81-
role: PromptMessageRoleEnum.USER
81+
role: PromptMessageRoleEnum.USER,
82+
toolCalls: null,
83+
toolCallId: null
8284
},
8385
{
8486
content: 'Hi there {{name}}!',
85-
role: PromptMessageRoleEnum.ASSISTANT
87+
role: PromptMessageRoleEnum.ASSISTANT,
88+
toolCalls: null,
89+
toolCallId: null
8690
}
8791
],
8892
promptTools: []
@@ -121,11 +125,15 @@ describe('client', () => {
121125
promptMessages: [
122126
{
123127
content: 'Hello, world!',
124-
role: PromptMessageRoleEnum.USER
128+
role: PromptMessageRoleEnum.USER,
129+
toolCalls: null,
130+
toolCallId: null
125131
},
126132
{
127133
content: 'Hi there {{name}}!',
128-
role: PromptMessageRoleEnum.ASSISTANT
134+
role: PromptMessageRoleEnum.ASSISTANT,
135+
toolCalls: null,
136+
toolCallId: null
129137
}
130138
],
131139
promptTools: []
@@ -164,11 +172,15 @@ describe('client', () => {
164172
promptMessages: [
165173
{
166174
content: 'Hello, world!',
167-
role: PromptMessageRoleEnum.USER
175+
role: PromptMessageRoleEnum.USER,
176+
toolCalls: null,
177+
toolCallId: null
168178
},
169179
{
170180
content: 'Hi there!',
171-
role: PromptMessageRoleEnum.ASSISTANT
181+
role: PromptMessageRoleEnum.ASSISTANT,
182+
toolCalls: null,
183+
toolCallId: null
172184
}
173185
],
174186
promptTools: []
@@ -207,11 +219,15 @@ describe('client', () => {
207219
promptMessages: [
208220
{
209221
content: 'Hello, world!',
210-
role: PromptMessageRoleEnum.USER
222+
role: PromptMessageRoleEnum.USER,
223+
toolCalls: null,
224+
toolCallId: null
211225
},
212226
{
213227
content: 'Hi there {{name}}!',
214-
role: PromptMessageRoleEnum.ASSISTANT
228+
role: PromptMessageRoleEnum.ASSISTANT,
229+
toolCalls: null,
230+
toolCallId: null
215231
}
216232
],
217233
promptTools: []
@@ -250,11 +266,15 @@ describe('client', () => {
250266
promptMessages: [
251267
{
252268
content: 'Hello, world!',
253-
role: PromptMessageRoleEnum.USER
269+
role: PromptMessageRoleEnum.USER,
270+
toolCalls: null,
271+
toolCallId: null
254272
},
255273
{
256274
content: 'Hi there {{name}}!',
257-
role: PromptMessageRoleEnum.ASSISTANT
275+
role: PromptMessageRoleEnum.ASSISTANT,
276+
toolCalls: null,
277+
toolCallId: null
258278
}
259279
],
260280
promptTools: []

src/__tests__/__snapshots__/Client.test.ts.snap

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,14 @@ exports[`client getOpenAiPrompt returnsmapped prompt 1`] = `
77
"messages": [
88
{
99
"content": "Hello, world!",
10+
"name": undefined,
1011
"role": "user",
1112
},
1213
{
1314
"content": "Hi there!",
15+
"name": undefined,
1416
"role": "assistant",
17+
"tool_calls": undefined,
1518
},
1619
],
1720
"model": "gpt-3.5-turbo",

src/helpers/__tests__/template.test.ts

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -62,30 +62,30 @@ describe('template helpers', () => {
6262

6363
describe('extractVariablesFromMessages', () => {
6464
it('returns an array of template variables from messages', () => {
65-
const messages = [{ role: PromptMessageRoleEnum.USER, content: 'Hello, {{name}}' }]
65+
const messages = [{ role: PromptMessageRoleEnum.USER, content: 'Hello, {{name}}', toolCalls: null, toolCallId: null }]
6666
expect(extractVariablesFromMessages(messages)).toEqual(['name'])
6767
})
6868

6969
it('returns an array of template variables from messages - variable across mutliple messages', () => {
7070
const messages = [
71-
{ role: PromptMessageRoleEnum.USER, content: 'Hello, {{name}}' },
72-
{ role: PromptMessageRoleEnum.USER, content: 'Hello, {{name}}' }
71+
{ role: PromptMessageRoleEnum.USER, content: 'Hello, {{name}}', toolCalls: null, toolCallId: null },
72+
{ role: PromptMessageRoleEnum.USER, content: 'Hello, {{name}}', toolCalls: null, toolCallId: null }
7373
]
7474
expect(extractVariablesFromMessages(messages)).toEqual(['name'])
7575
})
7676

7777
it('returns an array of template variables from messages - no variables', () => {
7878
const messages = [
79-
{ role: PromptMessageRoleEnum.USER, content: 'Hello' },
80-
{ role: PromptMessageRoleEnum.USER, content: 'Hello' }
79+
{ role: PromptMessageRoleEnum.USER, content: 'Hello', toolCalls: null, toolCallId: null },
80+
{ role: PromptMessageRoleEnum.USER, content: 'Hello', toolCalls: null, toolCallId: null }
8181
]
8282
expect(extractVariablesFromMessages(messages)).toEqual([])
8383
})
8484

8585
it('returns an array of template variables from messages - mutliple variables', () => {
8686
const messages = [
87-
{ role: PromptMessageRoleEnum.USER, content: 'Hello, {{name}}' },
88-
{ role: PromptMessageRoleEnum.USER, content: 'Hello, {{firstName }} {{lastName}}' }
87+
{ role: PromptMessageRoleEnum.USER, content: 'Hello, {{name}}', toolCalls: null, toolCallId: null },
88+
{ role: PromptMessageRoleEnum.USER, content: 'Hello, {{firstName }} {{lastName}}', toolCalls: null, toolCallId: null }
8989
]
9090
expect(extractVariablesFromMessages(messages)).toEqual(['name', 'firstName', 'lastName'])
9191
})
@@ -136,15 +136,19 @@ describe('template helpers', () => {
136136
})
137137
describe('renderMessagesWithVariabels', () => {
138138
it('renders a simple message with one variable', () => {
139-
const messages = [{ role: PromptMessageRoleEnum.USER, content: 'Hello, {{name}}!' }]
139+
const messages = [{ role: PromptMessageRoleEnum.USER, content: 'Hello, {{name}}!', toolCalls: null, toolCallId: null }]
140140
const variables = { name: 'Alice' }
141-
expect(renderMessagesWithVariabels(messages, variables)).toEqual([{ role: PromptMessageRoleEnum.USER, content: 'Hello, Alice!' }])
141+
expect(renderMessagesWithVariabels(messages, variables)).toEqual([
142+
{ role: PromptMessageRoleEnum.USER, content: 'Hello, Alice!', toolCalls: null, toolCallId: null }
143+
])
142144
})
143145

144146
it('renders a simple message with no variables', () => {
145-
const messages = [{ role: PromptMessageRoleEnum.USER, content: 'Hello!' }]
147+
const messages = [{ role: PromptMessageRoleEnum.USER, content: 'Hello!', toolCalls: null, toolCallId: null }]
146148
const variables = {}
147-
expect(renderMessagesWithVariabels(messages, variables)).toEqual([{ role: PromptMessageRoleEnum.USER, content: 'Hello!' }])
149+
expect(renderMessagesWithVariabels(messages, variables)).toEqual([
150+
{ role: PromptMessageRoleEnum.USER, content: 'Hello!', toolCalls: null, toolCallId: null }
151+
])
148152
})
149153
})
150154
describe('renderPromptWithVariables', () => {
@@ -162,13 +166,13 @@ describe('template helpers', () => {
162166
seed: 0
163167
},
164168
promptTools: [],
165-
promptMessages: [{ role: PromptMessageRoleEnum.USER, content: 'Hello, {{name}}!' }]
169+
promptMessages: [{ role: PromptMessageRoleEnum.USER, content: 'Hello, {{name}}!', toolCalls: null, toolCallId: null }]
166170
}
167171
const variables = { name: 'Alice' }
168172

169173
expect(renderPromptWithVariables(prompt, variables)).toEqual({
170174
...prompt,
171-
promptMessages: [{ role: PromptMessageRoleEnum.USER, content: 'Hello, Alice!' }]
175+
promptMessages: [{ role: PromptMessageRoleEnum.USER, content: 'Hello, Alice!', toolCalls: null, toolCallId: null }]
172176
})
173177
})
174178

@@ -186,13 +190,13 @@ describe('template helpers', () => {
186190
seed: 0
187191
},
188192
promptTools: [],
189-
promptMessages: [{ role: PromptMessageRoleEnum.USER, content: 'Hello!' }]
193+
promptMessages: [{ role: PromptMessageRoleEnum.USER, content: 'Hello!', toolCalls: null, toolCallId: null }]
190194
}
191195
const variables = {}
192196

193197
expect(renderPromptWithVariables(prompt, variables)).toEqual({
194198
...prompt,
195-
promptMessages: [{ role: PromptMessageRoleEnum.USER, content: 'Hello!' }]
199+
promptMessages: [{ role: PromptMessageRoleEnum.USER, content: 'Hello!', toolCalls: null, toolCallId: null }]
196200
})
197201
})
198202
})
@@ -211,7 +215,7 @@ describe('template helpers', () => {
211215
seed: 0
212216
},
213217
promptTools: [],
214-
promptMessages: [{ role: PromptMessageRoleEnum.USER, content: 'Hello, {{name}}!' }]
218+
promptMessages: [{ role: PromptMessageRoleEnum.USER, content: 'Hello, {{name}}!', toolCalls: null, toolCallId: null }]
215219
}
216220
const variables = { name: 'Alice' }
217221

@@ -232,7 +236,7 @@ describe('template helpers', () => {
232236
seed: 0
233237
},
234238
promptTools: [],
235-
promptMessages: [{ role: PromptMessageRoleEnum.USER, content: 'Hello, {{name}}!' }]
239+
promptMessages: [{ role: PromptMessageRoleEnum.USER, content: 'Hello, {{name}}!', toolCalls: null, toolCallId: null }]
236240
}
237241
const variables = {}
238242

src/helpers/openAi.ts

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import {
22
ChatCompletionCreateParamsNonStreaming,
33
ChatCompletionMessageParam,
4+
ChatCompletionMessageToolCall,
45
ChatCompletionRole,
56
ChatCompletionTool,
67
ChatCompletionToolChoiceOption
@@ -9,10 +10,63 @@ import {
910
import { PromptConfiguration, PromptMessage, PromptTool } from '../types'
1011

1112
export const mapMessagesToOpenAI = (promptMessages: PromptMessage[]): ChatCompletionMessageParam[] => {
12-
return promptMessages.map((message) => ({
13-
role: message.role.toLowerCase() as ChatCompletionRole,
14-
content: message.content
15-
})) as ChatCompletionMessageParam[]
13+
return promptMessages.map((message): ChatCompletionMessageParam => {
14+
const role = message.role.toLowerCase() as ChatCompletionRole
15+
if (role === 'tool') {
16+
if (!message.toolCallId) {
17+
throw new Error('Tool call missing tool call id')
18+
}
19+
20+
if (!message.content) {
21+
throw new Error('Tool message missing content')
22+
}
23+
24+
return {
25+
role,
26+
content: message.content,
27+
tool_call_id: message.toolCallId
28+
}
29+
}
30+
if (role === 'assistant') {
31+
return {
32+
role,
33+
content: message.content,
34+
name: undefined,
35+
tool_calls: message.toolCalls?.map((toolCall): ChatCompletionMessageToolCall => {
36+
return {
37+
id: toolCall.toolCallId,
38+
type: toolCall.type,
39+
function: toolCall.function
40+
}
41+
})
42+
}
43+
}
44+
45+
if (role === 'user') {
46+
if (!message.content) {
47+
throw new Error('User message missing content')
48+
}
49+
50+
return {
51+
role,
52+
name: undefined,
53+
content: message.content
54+
}
55+
}
56+
57+
if (role === 'system') {
58+
if (!message.content) {
59+
throw new Error('System message missing content')
60+
}
61+
62+
return {
63+
role,
64+
content: message.content
65+
}
66+
}
67+
68+
throw new Error(`Invalid message role: ${role}`)
69+
})
1670
}
1771

1872
const mapToolChoiceToOpenAI = (tools: PromptTool[], toolChoice?: string | null): ChatCompletionToolChoiceOption => {

src/helpers/template.ts

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import { PromptConfiguration, PromptMessage } from '../types'
22

3-
export function extractVariables(template: string): string[] {
3+
export function extractVariables(template: string | null): string[] {
4+
if (!template) {
5+
return []
6+
}
47
const regex = /\{\{([^{}]+)\}\}/g
58
const variables = new Set<string>()
69
let match: RegExpExecArray | null
@@ -21,7 +24,10 @@ export function extractVariablesFromMessages(messages: PromptMessage[]): string[
2124
return Array.from(new Set(variables.flat()))
2225
}
2326

24-
export function renderTemplate(template: string, variables: Record<string, string>): string {
27+
export function renderTemplate(template: string | null, variables: Record<string, string>): string {
28+
if (!template) {
29+
return ''
30+
}
2531
return template.replace(/\{\{([^{}]+)\}\}/g, (_, key) => {
2632
return variables[key.trim()] || ''
2733
})
@@ -30,7 +36,7 @@ export function renderTemplate(template: string, variables: Record<string, strin
3036
export function renderMessagesWithVariabels(messages: PromptMessage[], variables: Record<string, string>): PromptMessage[] {
3137
return messages.map((message) => {
3238
return {
33-
role: message.role,
39+
...message,
3440
content: renderTemplate(message.content, variables)
3541
}
3642
})

src/types/openapi.d.ts

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,18 @@ declare namespace Components {
7979
* example:
8080
* Hello, {{city}}!
8181
*/
82-
content: string;
82+
content: string | null;
8383
/**
8484
* example:
85-
* user
85+
* USER
8686
*/
8787
role: "USER" | "ASSISTANT" | "SYSTEM" | "TOOL";
88+
/**
89+
* example:
90+
* TOOL_CALL_1
91+
*/
92+
toolCallId: string | null;
93+
toolCalls: ToolFunctionCall[] | null;
8894
}
8995
export interface Tool {
9096
/**
@@ -102,6 +108,33 @@ declare namespace Components {
102108
description: string;
103109
parameters: /* The parameters the functions accepts, described as a JSON Schema object. This schema is designed to match the TypeScript Record<string, unknown>, allowing for any properties with values of any type. */ ToolParameters;
104110
}
111+
export interface ToolFunctionCall {
112+
/**
113+
* example:
114+
* TOOL_CALL_1
115+
*/
116+
toolCallId: string;
117+
/**
118+
* The type of the tool. Currently, only `function` is supported.
119+
* example:
120+
* function
121+
*/
122+
type: "function";
123+
function: {
124+
/**
125+
* The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function.
126+
* example:
127+
* {}
128+
*/
129+
arguments: string;
130+
/**
131+
* The name of the function to call.
132+
* example:
133+
* checkWeather
134+
*/
135+
name: string;
136+
};
137+
}
105138
/**
106139
* The parameters the functions accepts, described as a JSON Schema object. This schema is designed to match the TypeScript Record<string, unknown>, allowing for any properties with values of any type.
107140
*/

0 commit comments

Comments
 (0)