@@ -9,19 +9,52 @@ import types from require "tableshape"
99
1010parse_url = require ( " socket.url" ) . parse
1111
12- test_message = types. shape {
13- role : types. one_of { " system" , " user" , " assistant" }
14- content : types. string
15- name : types. nil + types. string
12+ empty = ( types. nil + types. literal( cjson. null)) \ describe " nullable"
13+
14+ test_message = types. one_of {
15+ types. shape {
16+ role : types. one_of { " system" , " user" , " assistant" }
17+ content : empty + types. string -- this can be empty when function_call is set
18+ name : empty + types. string
19+ function_call : empty + types. table
20+ }
21+
22+ -- this message type is for sending a function call response back
23+ types. shape {
24+ role : types. one_of { " function" }
25+ name : types. string
26+ content : empty + types. string
27+ }
28+ }
29+
30+ -- verify the shape of a function declaration
31+ test_function = types. shape {
32+ name : types. string
33+ description : types. nil + types. string
34+ parameters : types. nil + types. table
1635}
1736
1837parse_chat_response = types. partial {
1938 usage : types. table\ tag " usage"
2039 choices : types. partial {
2140 types. partial {
22- message : types. partial( {
23- content : types. string\ tag " response"
24- role : " assistant"
41+ message : types. one_of( {
42+ -- if function call is requested, content is not required so we tag
43+ -- nothing so we can return the whole object
44+ types. partial( {
45+ role : " assistant"
46+ content : types. string + empty
47+ function_call : types. partial {
48+ name : types. string
49+ -- API returns arguments a string that should be in json format
50+ arguments : types. string
51+ }
52+ } )
53+
54+ types. partial {
55+ role : " assistant"
56+ content : types. string\ tag " response"
57+ }
2558 } ) \ tag " message"
2659 }
2760 }
@@ -81,7 +114,7 @@ consume_json_head = do
81114parse_error_message = types. partial {
82115 error : types. partial {
83116 message : types. string\ tag " message"
84- code : types. string\ tag " code"
117+ code : empty + types. string\ tag " code"
85118 }
86119}
87120
@@ -90,9 +123,16 @@ parse_error_message = types.partial {
90123class ChatSession
91124 new : ( @client , @opts = {} ) =>
92125 @messages = {}
126+ @functions = {}
127+
93128 if type ( @opts . messages) == " table"
94129 @append_message unpack @opts . messages
95130
131+ if type ( @opts . functions) == " table"
132+ for func in * @opts . functions
133+ assert test_function func
134+ table.insert @functions , func
135+
96136 append_message : ( m, ... ) =>
97137 assert test_message m
98138 table.insert @messages , m
@@ -119,16 +159,22 @@ class ChatSession
119159 -- stream_callback: provide a function to enable streaming output. function will receive each chunk as it's generated
120160 generate_response : ( append_response= true , stream_callback= nil ) =>
121161 status, response = @client \ chat @messages , {
162+ function_call : @opts . function_call -- override the default function call behavior
163+ functions : @functions
122164 model : @opts . model
123165 temperature : @opts . temperature
124166 stream : stream_callback and true or nil
125167 } , stream_callback
126168
127169 if status != 200
128- err_msg = if err = parse_error_message response
129- " Bad status: #{status}: #{err.message} (#{err.code})"
130- else
131- " Bad status: #{status}"
170+ err_msg = " Bad status: #{status}"
171+
172+ if err = parse_error_message response
173+ if err. message
174+ err_msg ..= " : #{err.message}"
175+
176+ if err. code
177+ err_msg ..= " (#{err.code})"
132178
133179 return nil , err_msg, response
134180
@@ -160,8 +206,8 @@ class ChatSession
160206 if append_response
161207 @append_message out. message
162208
163- out . response
164-
209+ -- response is missing for function_calls, so we return the entire message object
210+ out . response or out . message
165211
166212class OpenAI
167213 api_base : " https://api.openai.com/v1"
0 commit comments