Skip to content

Commit 4031b83

Browse files
committed
update openai mapping
1 parent 5ecc8d9 commit 4031b83

File tree

3 files changed

+132
-16
lines changed

3 files changed

+132
-16
lines changed

src/Client.ts

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
import { ChatCompletionCreateParamsNonStreaming } from 'openai/resources'
1+
import { ChatCompletionCreateParamsNonStreaming, ChatCompletionMessageParam } from 'openai/resources'
22

33
import { addAppendedMessages, addOverrideMessages, getMissingPromptVariables, renderPromptWithVariables, validatePromptVariables } from './helpers'
4-
import { mapPromptToOpenAIConfig } from './helpers/openAi'
4+
import { mapOpenAIMessagesToMessages, mapPromptToOpenAIConfig } from './helpers/openAi'
55
import { api, createApiClient } from './openapi/client'
66
import {
77
Evaluation,
@@ -76,11 +76,19 @@ export default class PromptFoundry {
7676
}: {
7777
id: string
7878
variables: Record<string, string>
79-
appendMessages?: PromptMessage[]
80-
overrideMessages?: PromptMessage[]
79+
appendMessages?: ChatCompletionMessageParam[]
80+
overrideMessages?: ChatCompletionMessageParam[]
8181
user?: string
8282
}): Promise<ChatCompletionCreateParamsNonStreaming> {
83-
const updatedWithVariables = await this.getPrompt({ id, variables, appendMessages, overrideMessages })
83+
const appendMessagesMapped = appendMessages ? mapOpenAIMessagesToMessages(appendMessages) : undefined
84+
const overrideMessagesMapped = overrideMessages ? mapOpenAIMessagesToMessages(overrideMessages) : undefined
85+
86+
const updatedWithVariables = await this.getPrompt({
87+
id,
88+
variables,
89+
appendMessages: appendMessagesMapped,
90+
overrideMessages: overrideMessagesMapped
91+
})
8492

8593
return mapPromptToOpenAIConfig(updatedWithVariables, {
8694
user

src/helpers/__tests__/openAi.test.ts

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ import { createPromptConfigurationFixture } from '../../test/__fixtures__/prompt
22
import { createPromptMessageFixture } from '../../test/__fixtures__/promptMessage'
33
import { createPromptToolFixture } from '../../test/__fixtures__/promptTool'
44
import { PromptMessageRoleEnum, PromptTool } from '../../types'
5-
import { mapMessagesToOpenAI, mapPromptToOpenAIConfig, mapToolChoiceToOpenAI } from '../openAi'
5+
import { mapMessagesToOpenAIMessages, mapOpenAIMessagesToMessages, mapPromptToOpenAIConfig, mapToolChoiceToOpenAI } from '../openAi'
66

77
describe('openAi helpers', () => {
8-
describe('mapMessagesToOpenAI', () => {
8+
describe('mapMessagesToOpenAIMessages', () => {
99
const mockMessages = [
1010
createPromptMessageFixture({ role: PromptMessageRoleEnum.TOOL, content: 'content1', toolCallId: 'id1' }),
1111
createPromptMessageFixture({ role: PromptMessageRoleEnum.ASSISTANT, content: 'content2' }),
@@ -14,7 +14,7 @@ describe('openAi helpers', () => {
1414
]
1515

1616
it('should correctly map tool messages', () => {
17-
const result = mapMessagesToOpenAI([mockMessages[0]])
17+
const result = mapMessagesToOpenAIMessages([mockMessages[0]])
1818
expect(result[0]).toEqual({
1919
role: 'tool',
2020
content: 'content1',
@@ -23,7 +23,7 @@ describe('openAi helpers', () => {
2323
})
2424

2525
it('should correctly map assistant messages', () => {
26-
const result = mapMessagesToOpenAI([mockMessages[1]])
26+
const result = mapMessagesToOpenAIMessages([mockMessages[1]])
2727
expect(result[0]).toEqual({
2828
role: 'assistant',
2929
content: 'content2',
@@ -33,7 +33,7 @@ describe('openAi helpers', () => {
3333
})
3434

3535
it('should correctly map user messages', () => {
36-
const result = mapMessagesToOpenAI([mockMessages[2]])
36+
const result = mapMessagesToOpenAIMessages([mockMessages[2]])
3737
expect(result[0]).toEqual({
3838
role: 'user',
3939
name: undefined,
@@ -42,22 +42,47 @@ describe('openAi helpers', () => {
4242
})
4343

4444
it('should correctly map system messages', () => {
45-
const result = mapMessagesToOpenAI([mockMessages[3]])
45+
const result = mapMessagesToOpenAIMessages([mockMessages[3]])
4646
expect(result[0]).toEqual({
4747
role: 'system',
4848
content: 'content4'
4949
})
5050
})
5151

5252
it('should throw an error for invalid message roles', () => {
53-
expect(() => mapMessagesToOpenAI([createPromptMessageFixture({ role: 'invalid', content: 'content5' } as any)])).toThrow(
53+
expect(() => mapMessagesToOpenAIMessages([createPromptMessageFixture({ role: 'invalid', content: 'content5' } as any)])).toThrow(
5454
'Invalid message role: invalid'
5555
)
5656
})
5757

5858
// Additional test cases for missing fields and other edge cases can be added here.
5959
})
6060

61+
describe('mapOpenAIMessagesToMessages', () => {
62+
const mockMessages = [
63+
createPromptMessageFixture({ role: PromptMessageRoleEnum.TOOL, content: 'content1', toolCallId: 'id1' }),
64+
createPromptMessageFixture({ role: PromptMessageRoleEnum.ASSISTANT, content: 'content2' }),
65+
createPromptMessageFixture({ role: PromptMessageRoleEnum.USER, content: 'content3' }),
66+
createPromptMessageFixture({ role: PromptMessageRoleEnum.SYSTEM, content: 'content4' })
67+
]
68+
69+
it('should correctly map tool messages', () => {
70+
expect(mapOpenAIMessagesToMessages(mapMessagesToOpenAIMessages([mockMessages[0]]))).toEqual([mockMessages[0]])
71+
})
72+
73+
it('should correctly map assistant messages', () => {
74+
expect(mapOpenAIMessagesToMessages(mapMessagesToOpenAIMessages([mockMessages[1]]))).toEqual([mockMessages[1]])
75+
})
76+
77+
it('should correctly map user messages', () => {
78+
expect(mapOpenAIMessagesToMessages(mapMessagesToOpenAIMessages([mockMessages[2]]))).toEqual([mockMessages[2]])
79+
})
80+
81+
it('should correctly map system messages', () => {
82+
expect(mapOpenAIMessagesToMessages(mapMessagesToOpenAIMessages([mockMessages[3]]))).toEqual([mockMessages[3]])
83+
})
84+
})
85+
6186
describe('mapPromptToOpenAIConfig', () => {
6287
it('should map the configuration to OpenAI parameters correctly', () => {
6388
const result = mapPromptToOpenAIConfig(createPromptConfigurationFixture())

src/helpers/openAi.ts

Lines changed: 87 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
import {
2+
ChatCompletionAssistantMessageParam,
23
ChatCompletionCreateParamsNonStreaming,
34
ChatCompletionMessageParam,
45
ChatCompletionMessageToolCall,
56
ChatCompletionRole,
67
ChatCompletionTool,
7-
ChatCompletionToolChoiceOption
8+
ChatCompletionToolChoiceOption,
9+
ChatCompletionToolMessageParam
810
} from 'openai/resources'
911

10-
import { PromptConfiguration, PromptMessage, PromptTool } from '../types'
12+
import { PromptConfiguration, PromptMessage, PromptMessageRole, PromptTool } from '../types'
1113

12-
export const mapMessagesToOpenAI = (messages: PromptMessage[]): ChatCompletionMessageParam[] => {
14+
export const mapMessagesToOpenAIMessages = (messages: PromptMessage[]): ChatCompletionMessageParam[] => {
1315
return messages.map((message): ChatCompletionMessageParam => {
1416
const role = message.role.toLowerCase() as ChatCompletionRole
1517
if (role === 'tool') {
@@ -69,6 +71,87 @@ export const mapMessagesToOpenAI = (messages: PromptMessage[]): ChatCompletionMe
6971
})
7072
}
7173

74+
function isToolMessage(message: ChatCompletionMessageParam): message is ChatCompletionToolMessageParam {
75+
return message.role === 'tool'
76+
}
77+
78+
function isAssistantMessage(message: ChatCompletionMessageParam): message is ChatCompletionAssistantMessageParam {
79+
return message.role === 'assistant'
80+
}
81+
82+
export const mapOpenAIMessagesToMessages = (messages: ChatCompletionMessageParam[]): PromptMessage[] => {
83+
return messages.map((message): PromptMessage => {
84+
const role = message.role.toUpperCase() as PromptMessageRole
85+
86+
if (isToolMessage(message)) {
87+
if (!message.tool_call_id) {
88+
throw new Error('Tool call missing tool call id')
89+
}
90+
91+
if (!message.content) {
92+
throw new Error('Tool message missing content')
93+
}
94+
return {
95+
role,
96+
content: message.content,
97+
toolCallId: message.tool_call_id,
98+
toolCalls: null
99+
}
100+
}
101+
102+
if (isAssistantMessage(message)) {
103+
const toolCalls: PromptMessage['toolCalls'] =
104+
message.tool_calls?.map((toolCall) => {
105+
return {
106+
toolCallId: toolCall.id,
107+
type: toolCall.type,
108+
function: {
109+
name: toolCall.function.name,
110+
arguments: toolCall.function.arguments
111+
}
112+
}
113+
}) || null
114+
115+
return {
116+
role,
117+
content: message.content as string,
118+
name: undefined,
119+
toolCallId: null,
120+
toolCalls
121+
}
122+
}
123+
124+
if (role === 'USER') {
125+
if (!message.content) {
126+
throw new Error('User message missing content')
127+
}
128+
129+
return {
130+
role,
131+
name: undefined,
132+
content: message.content as string,
133+
toolCallId: null,
134+
toolCalls: null
135+
}
136+
}
137+
138+
if (role === 'SYSTEM') {
139+
if (!message.content) {
140+
throw new Error('System message missing content')
141+
}
142+
143+
return {
144+
role,
145+
content: message.content as string,
146+
toolCallId: null,
147+
toolCalls: null
148+
}
149+
}
150+
151+
throw new Error(`Invalid message role: ${role as string}`)
152+
})
153+
}
154+
72155
export const mapToolChoiceToOpenAI = (tools: PromptTool[], toolChoice?: string | null): ChatCompletionToolChoiceOption | undefined => {
73156
if (tools.length === 0) {
74157
return undefined
@@ -120,7 +203,7 @@ export const mapPromptToOpenAIConfig = (
120203
): ChatCompletionCreateParamsNonStreaming => {
121204
const { messages: promptMessages, parameters, tools } = promptConfig
122205

123-
const messages = mapMessagesToOpenAI(promptMessages)
206+
const messages = mapMessagesToOpenAIMessages(promptMessages)
124207

125208
return {
126209
messages,

0 commit comments

Comments
 (0)