Skip to content

Commit 8df76ce

Browse files
committed
openai mapping tests and bug fixes
1 parent ab6aec2 commit 8df76ce

File tree

8 files changed

+172
-8
lines changed

8 files changed

+172
-8
lines changed

src/__tests__/__snapshots__/Client.test.ts.snap

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ exports[`client getOpenAiPrompt returns mapped prompt 1`] = `
2424
},
2525
"seed": null,
2626
"temperature": 0.7,
27-
"tool_choice": "auto",
28-
"tools": [],
27+
"tool_choice": "none",
28+
"tools": undefined,
2929
"top_p": 1,
3030
}
3131
`;
src/helpers/__tests__/openAi.test.ts

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import { createPromptConfigurationFixture } from '../../../test/__fixtures__/promptConfiguration'
2+
import { createPromptMessageFixture } from '../../../test/__fixtures__/promptMessage'
3+
import { createPromptToolFixture } from '../../../test/__fixtures__/promptTool'
4+
import { PromptMessageRoleEnum, PromptTool } from '../../types'
5+
import { mapMessagesToOpenAI, mapPromptToOpenAIConfig, mapToolChoiceToOpenAI } from '../openAi'
6+
7+
describe('openAi helpers', () => {
8+
describe('mapMessagesToOpenAI', () => {
9+
const mockMessages = [
10+
createPromptMessageFixture({ role: PromptMessageRoleEnum.TOOL, content: 'content1', toolCallId: 'id1' }),
11+
createPromptMessageFixture({ role: PromptMessageRoleEnum.ASSISTANT, content: 'content2' }),
12+
createPromptMessageFixture({ role: PromptMessageRoleEnum.USER, content: 'content3' }),
13+
createPromptMessageFixture({ role: PromptMessageRoleEnum.SYSTEM, content: 'content4' })
14+
]
15+
16+
it('should correctly map tool messages', () => {
17+
const result = mapMessagesToOpenAI([mockMessages[0]])
18+
expect(result[0]).toEqual({
19+
role: 'tool',
20+
content: 'content1',
21+
tool_call_id: 'id1'
22+
})
23+
})
24+
25+
it('should correctly map assistant messages', () => {
26+
const result = mapMessagesToOpenAI([mockMessages[1]])
27+
expect(result[0]).toEqual({
28+
role: 'assistant',
29+
content: 'content2',
30+
name: undefined,
31+
tool_calls: undefined
32+
})
33+
})
34+
35+
it('should correctly map user messages', () => {
36+
const result = mapMessagesToOpenAI([mockMessages[2]])
37+
expect(result[0]).toEqual({
38+
role: 'user',
39+
name: undefined,
40+
content: 'content3'
41+
})
42+
})
43+
44+
it('should correctly map system messages', () => {
45+
const result = mapMessagesToOpenAI([mockMessages[3]])
46+
expect(result[0]).toEqual({
47+
role: 'system',
48+
content: 'content4'
49+
})
50+
})
51+
52+
it('should throw an error for invalid message roles', () => {
53+
expect(() => mapMessagesToOpenAI([createPromptMessageFixture({ role: 'invalid', content: 'content5' } as any)])).toThrow(
54+
'Invalid message role: invalid'
55+
)
56+
})
57+
58+
// Additional test cases for missing fields and other edge cases can be added here.
59+
})
60+
61+
describe('mapPromptToOpenAIConfig', () => {
62+
const mockPromptConfig = createPromptConfigurationFixture()
63+
64+
it('should map the configuration to OpenAI parameters correctly', () => {
65+
const result = mapPromptToOpenAIConfig(mockPromptConfig)
66+
expect(result).toEqual({
67+
messages: [{ role: 'user', name: undefined, content: 'Hello world' }],
68+
model: 'text-davinci-002',
69+
top_p: 0.5,
70+
max_tokens: 150,
71+
temperature: 0.7,
72+
seed: 42,
73+
presence_penalty: 0.1,
74+
frequency_penalty: 0.1,
75+
tool_choice: 'none',
76+
response_format: {
77+
type: 'json_object'
78+
},
79+
tools: undefined
80+
})
81+
})
82+
83+
// Additional test cases to test other configurations and scenarios can be added here.
84+
})
85+
86+
describe('mapToolChoiceToOpenAI', () => {
87+
it('should return "auto" if toolChoice is "auto" or there are tools available and toolChoice is undefined', () => {
88+
const tools = [createPromptToolFixture()]
89+
expect(mapToolChoiceToOpenAI(tools, 'auto')).toBe('auto')
90+
expect(mapToolChoiceToOpenAI(tools)).toBe('auto')
91+
})
92+
93+
it('should return "none" if there are no tools or toolChoice is "none"', () => {
94+
const tools: PromptTool[] = []
95+
expect(mapToolChoiceToOpenAI(tools, 'none')).toBe('none')
96+
expect(mapToolChoiceToOpenAI(tools)).toBe('none')
97+
expect(mapToolChoiceToOpenAI(tools, 'auto')).toBe('none')
98+
})
99+
100+
it('should return tool function object if a valid toolChoice matches a tool', () => {
101+
const tools = [createPromptToolFixture({ name: 'exampleTool' })]
102+
expect(mapToolChoiceToOpenAI(tools, 'exampleTool')).toEqual({
103+
type: 'function',
104+
function: { name: 'exampleTool' }
105+
})
106+
})
107+
108+
it('should return "none" if the toolChoice does not match any tool', () => {
109+
const tools = [createPromptToolFixture({ name: 'exampleTool' })]
110+
expect(mapToolChoiceToOpenAI(tools, 'nonexistentTool')).toBe('none')
111+
})
112+
})
113+
})

src/helpers/openAi.ts

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,14 @@ export const mapMessagesToOpenAI = (promptMessages: PromptMessage[]): ChatComple
6969
})
7070
}
7171

72-
const mapToolChoiceToOpenAI = (tools: PromptTool[], toolChoice?: string | null): ChatCompletionToolChoiceOption => {
72+
export const mapToolChoiceToOpenAI = (tools: PromptTool[], toolChoice?: string | null): ChatCompletionToolChoiceOption => {
73+
if (tools.length === 0) {
74+
return 'none'
75+
}
7376
if (toolChoice === 'auto' || (!toolChoice && tools.length !== 0)) {
7477
return 'auto'
7578
}
76-
if (toolChoice === 'none' || tools.length === 0) {
79+
if (toolChoice === 'none') {
7780
return 'none'
7881
}
7982

@@ -103,6 +106,14 @@ const mapToolToOpenAi = (tool: PromptTool): ChatCompletionTool => {
103106
}
104107
}
105108

109+
const getTools = (promptTools: PromptTool[]): ChatCompletionTool[] | undefined => {
110+
if (promptTools.length === 0) {
111+
return undefined
112+
}
113+
114+
return promptTools.map((tool) => mapToolToOpenAi(tool))
115+
}
116+
106117
export const mapPromptToOpenAIConfig = (promptConfig: PromptConfiguration): ChatCompletionCreateParamsNonStreaming => {
107118
const { promptMessages, promptParameters, promptTools } = promptConfig
108119

@@ -121,6 +132,6 @@ export const mapPromptToOpenAIConfig = (promptConfig: PromptConfiguration): Chat
121132
response_format: {
122133
type: promptParameters.responseFormat === 'JSON' ? 'json_object' : 'text'
123134
},
124-
tools: promptTools.map((tool) => mapToolToOpenAi(tool))
135+
tools: getTools(promptTools)
125136
}
126137
}

src/index.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import Client from './Client'
22

33
export * from './helpers'
4-
export type * from './types'
4+
export * from './types'
55

66
export default Client
test/__fixtures__/promptConfiguration.ts

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import { PromptConfiguration } from '../../src'
2+
3+
import { createPromptMessageFixture } from './promptMessage'
4+
5+
export const createPromptConfigurationFixture = (overrides: Partial<PromptConfiguration> = {}): PromptConfiguration => ({
6+
promptId: 'promptId',
7+
promptMessages: [createPromptMessageFixture()],
8+
promptParameters: {
9+
modelName: 'text-davinci-002',
10+
topP: 0.5,
11+
maxTokens: 150,
12+
temperature: 0.7,
13+
seed: 42,
14+
presencePenalty: 0.1,
15+
frequencyPenalty: 0.1,
16+
toolChoice: 'auto',
17+
responseFormat: 'JSON'
18+
},
19+
promptTools: [],
20+
...overrides
21+
})

test/__fixtures__/promptMessage.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import { PromptMessage, PromptMessageRoleEnum } from '../../src'
2+
3+
export const createPromptMessageFixture = (overrides: Partial<PromptMessage> = {}): PromptMessage => ({
4+
content: 'Hello world',
5+
role: PromptMessageRoleEnum.USER,
6+
toolCallId: null,
7+
toolCalls: null,
8+
...overrides
9+
})

test/__fixtures__/promptTool.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import { PromptTool } from '../../src'
2+
3+
export const createPromptToolFixture = (overrides: Partial<PromptTool> = {}): PromptTool => ({
4+
toolId: 'toolId',
5+
description: 'description',
6+
name: 'name',
7+
parameters: {},
8+
...overrides
9+
})

tsconfig.json

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"incremental": true,
44
"target": "es2017",
55
"outDir": "build/main",
6-
"rootDir": "src",
6+
"rootDir": "./",
77
"moduleResolution": "node",
88
"module": "commonjs",
99
"declaration": true,
@@ -28,7 +28,8 @@
2828
]
2929
},
3030
"include": [
31-
"src/**/*.ts"
31+
"src/**/*.ts",
32+
"test/**/*.ts"
3233
],
3334
"exclude": [
3435
"node_modules/**"

0 commit comments

Comments (0)