Skip to content

Commit 6b23702

Browse files
committed
add an example for using functions to calculate standard deviation
1 parent fa5f3db commit 6b23702

File tree

1 file changed

+129
-0
lines changed

1 file changed

+129
-0
lines changed

examples/example5.lua

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
2+
-- This example that will attempt to calcualte the standard deviation of a
3+
-- number by calling functions tcalling functions that have been defined.
4+
--
5+
6+
local cjson = require("cjson")
7+
local types = require("tableshape").types
8+
9+
local openai = require("openai")
10+
local client = openai.new(os.getenv("OPENAI_API_KEY"))
11+
12+
-- Helper debug print function to show contents of table as json
13+
local function p(...)
14+
local chunks = {...}
15+
for i, chunk in ipairs(chunks) do
16+
if type(chunk) ~= "string" then
17+
chunks[i] = cjson.encode(chunk)
18+
end
19+
end
20+
21+
print(unpack(chunks))
22+
end
23+
24+
local two_numbers = {
25+
type = "object",
26+
properties = {
27+
a = { type = "number" },
28+
b = { type = "number" }
29+
}
30+
}
31+
32+
local chat = client:new_chat_session({
33+
temperature = 0,
34+
-- model = "gpt-3.5-turbo-0613",
35+
model = "gpt-4-0613",
36+
messages = {
37+
{
38+
role = "system",
39+
content = "You are a calculator with access to specified set of functions. All computation should be done with the functions"
40+
}
41+
},
42+
functions = {
43+
{ name = "add", description = "Add two numbers together", parameters = two_numbers },
44+
{ name = "divide", description = "Divide two numbers", parameters = two_numbers },
45+
{ name = "multiply", description = "Multiply two numbers together", parameters = two_numbers },
46+
{
47+
name = "sqrt", description = "Calculate square root of a number",
48+
parameters = {
49+
type = "object",
50+
properties = {
51+
a = { type = "number" }
52+
}
53+
}
54+
}
55+
}
56+
})
57+
58+
-- override send method with logging
59+
local chat_send = chat.send
60+
function chat:send(v, ...)
61+
p(">>", v)
62+
return chat_send(self, v, ...)
63+
end
64+
65+
66+
local one_args = types.annotate(types.string / cjson.decode * types.partial({
67+
a = types.number
68+
}))
69+
70+
local two_args = types.annotate(types.string / cjson.decode * types.partial({
71+
a = types.number,
72+
b = types.number
73+
}))
74+
75+
local funcs = {
76+
add = {
77+
arguments = two_args,
78+
call = function(args) return args.a + args.b end
79+
},
80+
divide = {
81+
arguments = two_args,
82+
call = function(args) return args.a / args.b end
83+
},
84+
multiply = {
85+
arguments = two_args,
86+
call = function(args) return args.a * args.b end
87+
},
88+
sqrt = {
89+
arguments = one_args,
90+
call = function(args)
91+
return math.sqrt(args.a)
92+
end
93+
}
94+
}
95+
96+
assert(chat:send("Calculate the standard deviation of the numbers: 2, 8, 28, 92, 9"))
97+
98+
while true do
99+
local last_message = chat:last_message()
100+
101+
for k, v in pairs(last_message) do
102+
p(k, v)
103+
end
104+
105+
-- stop if no functions are requested
106+
if not last_message.function_call then
107+
break
108+
end
109+
110+
local func = last_message.function_call
111+
local func_handler = funcs[func.name]
112+
113+
if not func_handler then
114+
assert(chat:send("You called a function that is not declared: " .. func.name))
115+
else
116+
local arguments, err = func_handler.arguments:transform(func.arguments)
117+
if not arguments then
118+
assert(chat:send("Invalid arguments for function " .. func.name .. ": " .. err))
119+
else
120+
local result = func_handler.call(arguments)
121+
assert(chat:send({
122+
role = "function",
123+
name = func.name,
124+
content = cjson.encode(result)
125+
}))
126+
end
127+
end
128+
end
129+

0 commit comments

Comments
 (0)