Skip to content

Commit b467b1d

Browse files
committed
add support for functions for chat API
1 parent fb545c0 commit b467b1d

File tree

2 files changed

+117
-31
lines changed

2 files changed

+117
-31
lines changed

openai/init.lua

Lines changed: 57 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,48 @@ local unpack = table.unpack or unpack
55
local types
66
types = require("tableshape").types
77
local parse_url = require("socket.url").parse
8-
local test_message = types.shape({
9-
role = types.one_of({
10-
"system",
11-
"user",
12-
"assistant"
8+
local empty = (types["nil"] + types.literal(cjson.null)):describe("nullable")
9+
local test_message = types.one_of({
10+
types.shape({
11+
role = types.one_of({
12+
"system",
13+
"user",
14+
"assistant"
15+
}),
16+
content = empty + types.string,
17+
name = empty + types.string,
18+
function_call = empty + types.table
1319
}),
14-
content = types.string,
15-
name = types["nil"] + types.string
20+
types.shape({
21+
role = types.one_of({
22+
"function"
23+
}),
24+
name = types.string,
25+
content = empty + types.string
26+
})
27+
})
28+
local test_function = types.shape({
29+
name = types.string,
30+
description = types["nil"] + types.string,
31+
parameters = types["nil"] + types.table
1632
})
1733
local parse_chat_response = types.partial({
1834
usage = types.table:tag("usage"),
1935
choices = types.partial({
2036
types.partial({
21-
message = types.partial({
22-
content = types.string:tag("response"),
23-
role = "assistant"
37+
message = types.one_of({
38+
types.partial({
39+
role = "assistant",
40+
content = types.string + empty,
41+
function_call = types.partial({
42+
name = types.string,
43+
arguments = types.string
44+
})
45+
}),
46+
types.partial({
47+
role = "assistant",
48+
content = types.string:tag("response")
49+
})
2450
}):tag("message")
2551
})
2652
})
@@ -62,7 +88,7 @@ end
6288
local parse_error_message = types.partial({
6389
error = types.partial({
6490
message = types.string:tag("message"),
65-
code = types.string:tag("code")
91+
code = empty + types.string:tag("code")
6692
})
6793
})
6894
local ChatSession
@@ -100,18 +126,23 @@ do
100126
stream_callback = nil
101127
end
102128
local status, response = self.client:chat(self.messages, {
129+
function_call = self.opts.function_call,
130+
functions = self.functions,
103131
model = self.opts.model,
104132
temperature = self.opts.temperature,
105133
stream = stream_callback and true or nil
106134
}, stream_callback)
107135
if status ~= 200 then
108-
local err_msg
136+
local err_msg = "Bad status: " .. tostring(status)
109137
do
110138
local err = parse_error_message(response)
111139
if err then
112-
err_msg = "Bad status: " .. tostring(status) .. ": " .. tostring(err.message) .. " (" .. tostring(err.code) .. ")"
113-
else
114-
err_msg = "Bad status: " .. tostring(status)
140+
if err.message then
141+
err_msg = err_msg .. ": " .. tostring(err.message)
142+
end
143+
if err.code then
144+
err_msg = err_msg .. " (" .. tostring(err.code) .. ")"
145+
end
115146
end
116147
end
117148
return nil, err_msg, response
@@ -139,7 +170,7 @@ do
139170
if append_response then
140171
self:append_message(out.message)
141172
end
142-
return out.response
173+
return out.response or out.message
143174
end
144175
}
145176
_base_0.__index = _base_0
@@ -150,8 +181,17 @@ do
150181
end
151182
self.client, self.opts = client, opts
152183
self.messages = { }
184+
self.functions = { }
153185
if type(self.opts.messages) == "table" then
154-
return self:append_message(unpack(self.opts.messages))
186+
self:append_message(unpack(self.opts.messages))
187+
end
188+
if type(self.opts.functions) == "table" then
189+
local _list_0 = self.opts.functions
190+
for _index_0 = 1, #_list_0 do
191+
local func = _list_0[_index_0]
192+
assert(test_function(func))
193+
table.insert(self.functions, func)
194+
end
155195
end
156196
end,
157197
__base = _base_0,

openai/init.moon

Lines changed: 60 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,52 @@ import types from require "tableshape"
99

1010
parse_url = require("socket.url").parse
1111

12-
test_message = types.shape {
13-
role: types.one_of {"system", "user", "assistant"}
14-
content: types.string
15-
name: types.nil + types.string
12+
empty = (types.nil + types.literal(cjson.null))\describe "nullable"
13+
14+
test_message = types.one_of {
15+
types.shape {
16+
role: types.one_of {"system", "user", "assistant"}
17+
content: empty + types.string -- this can be empty when function_call is set
18+
name: empty + types.string
19+
function_call: empty + types.table
20+
}
21+
22+
-- this message type is for sending a function call response back
23+
types.shape {
24+
role: types.one_of {"function"}
25+
name: types.string
26+
content: empty + types.string
27+
}
28+
}
29+
30+
-- verify the shape of a function declaration
31+
test_function = types.shape {
32+
name: types.string
33+
description: types.nil + types.string
34+
parameters: types.nil + types.table
1635
}
1736

1837
parse_chat_response = types.partial {
1938
usage: types.table\tag "usage"
2039
choices: types.partial {
2140
types.partial {
22-
message: types.partial({
23-
content: types.string\tag "response"
24-
role: "assistant"
41+
message: types.one_of({
42+
-- if function call is requested, content is not required so we tag
43+
-- nothing so we can return the whole object
44+
types.partial({
45+
role: "assistant"
46+
content: types.string + empty
47+
function_call: types.partial {
48+
name: types.string
49+
-- API returns arguments a string that should be in json format
50+
arguments: types.string
51+
}
52+
})
53+
54+
types.partial {
55+
role: "assistant"
56+
content: types.string\tag "response"
57+
}
2558
})\tag "message"
2659
}
2760
}
@@ -81,7 +114,7 @@ consume_json_head = do
81114
parse_error_message = types.partial {
82115
error: types.partial {
83116
message: types.string\tag "message"
84-
code: types.string\tag "code"
117+
code: empty + types.string\tag "code"
85118
}
86119
}
87120

@@ -90,9 +123,16 @@ parse_error_message = types.partial {
90123
class ChatSession
91124
new: (@client, @opts={}) =>
92125
@messages = {}
126+
@functions = {}
127+
93128
if type(@opts.messages) == "table"
94129
@append_message unpack @opts.messages
95130

131+
if type(@opts.functions) == "table"
132+
for func in *@opts.functions
133+
assert test_function func
134+
table.insert @functions, func
135+
96136
append_message: (m, ...) =>
97137
assert test_message m
98138
table.insert @messages, m
@@ -119,16 +159,22 @@ class ChatSession
119159
-- stream_callback: provide a function to enable streaming output. function will receive each chunk as it's generated
120160
generate_response: (append_response=true, stream_callback=nil) =>
121161
status, response = @client\chat @messages, {
162+
function_call: @opts.function_call -- override the default function call behavior
163+
functions: @functions
122164
model: @opts.model
123165
temperature: @opts.temperature
124166
stream: stream_callback and true or nil
125167
}, stream_callback
126168

127169
if status != 200
128-
err_msg = if err = parse_error_message response
129-
"Bad status: #{status}: #{err.message} (#{err.code})"
130-
else
131-
"Bad status: #{status}"
170+
err_msg = "Bad status: #{status}"
171+
172+
if err = parse_error_message response
173+
if err.message
174+
err_msg ..= ": #{err.message}"
175+
176+
if err.code
177+
err_msg ..= " (#{err.code})"
132178

133179
return nil, err_msg, response
134180

@@ -160,8 +206,8 @@ class ChatSession
160206
if append_response
161207
@append_message out.message
162208

163-
out.response
164-
209+
-- response is missing for function_calls, so we return the entire message object
210+
out.response or out.message
165211

166212
class OpenAI
167213
api_base: "https://api.openai.com/v1"

0 commit comments

Comments
 (0)