@@ -24,7 +24,7 @@ function differentInputTextAccepted(testCase, ValidTextInput)
2424 testCase .verifyWarningFree(@()addSystemMessage(msgs , ValidTextInput , ValidTextInput ));
2525 testCase .verifyWarningFree(@()addSystemMessage(msgs , ValidTextInput , ValidTextInput ));
2626 testCase .verifyWarningFree(@()addUserMessage(msgs , ValidTextInput ));
27- testCase .verifyWarningFree(@()addFunctionMessage (msgs , ValidTextInput , ValidTextInput ));
27+ testCase .verifyWarningFree(@()addToolMessage (msgs , ValidTextInput , ValidTextInput , ValidTextInput ));
2828 end
2929
3030
@@ -59,12 +59,13 @@ function userImageMessageIsAddedWithRemoteImg(testCase)
5959 testCase .verifyWarningFree(@()addUserMessageWithImages(msgs , prompt , img ));
6060 end
6161
62- function functionMessageIsAdded (testCase )
62+ function toolMessageIsAdded (testCase )
6363 prompt = " 20" ;
6464 name = " sin" ;
65+ id = " 123" ;
6566 msgs = openAIMessages ;
66- systemPrompt = struct(" role" , " function " , " name" , name , " content" , prompt );
67- msgs = addFunctionMessage (msgs , name , prompt );
67+ systemPrompt = struct(" tool_call_id " , id , " role" , " tool " , " name" , name , " content" , prompt );
68+ msgs = addToolMessage (msgs , id , name , prompt );
6869 testCase .verifyEqual(msgs.Messages{1 }, systemPrompt );
6970 end
7071
@@ -76,27 +77,39 @@ function assistantMessageIsAdded(testCase)
7677 testCase .verifyEqual(msgs.Messages{1 }, assistantPrompt );
7778 end
7879
79- function assistantFunctionCallMessageIsAdded (testCase )
80+ function assistantToolCallMessageIsAdded (testCase )
8081 msgs = openAIMessages ;
8182 functionName = " functionName" ;
8283 args = " {"" arg1"" : 1, "" arg2"" : 2, "" arg3"" : "" 3"" }" ;
8384 funCall = struct(" name" , functionName , " arguments" , args );
8485 toolCall = struct(" id" , " 123" , " type" , " function" , " function" , funCall );
85- functionCallPrompt = struct(" role" , " assistant" , " content" , " " ," tool_calls" , toolCall );
86- functionCallPrompt .tool_calls = {functionCallPrompt . tool_calls };
87- msgs = addResponseMessage(msgs , functionCallPrompt );
88- testCase .verifyEqual(msgs.Messages{1 }, functionCallPrompt );
86+ toolCallPrompt = struct(" role" , " assistant" , " content" , " " , " tool_calls" , [] );
87+ toolCallPrompt .tool_calls = {toolCall };
88+ msgs = addResponseMessage(msgs , toolCallPrompt );
89+ testCase .verifyEqual(msgs.Messages{1 }, toolCallPrompt );
8990 end
9091
91- function assistantFunctionCallMessageWithoutArgsIsAdded (testCase )
92+ function assistantToolCallMessageWithoutArgsIsAdded (testCase )
9293 msgs = openAIMessages ;
9394 functionName = " functionName" ;
9495 funCall = struct(" name" , functionName , " arguments" , " {}" );
9596 toolCall = struct(" id" , " 123" , " type" , " function" , " function" , funCall );
96- functionCallPrompt = struct(" role" , " assistant" , " content" , " " ," tool_calls" , toolCall );
97- functionCallPrompt.tool_calls = {functionCallPrompt .tool_calls };
98- msgs = addResponseMessage(msgs , functionCallPrompt );
99- testCase .verifyEqual(msgs.Messages{1 }, functionCallPrompt );
97+ toolCallPrompt = struct(" role" , " assistant" , " content" , " " ," tool_calls" , []);
98+ toolCallPrompt.tool_calls = {toolCall };
99+ msgs = addResponseMessage(msgs , toolCallPrompt );
100+ testCase .verifyEqual(msgs.Messages{1 }, toolCallPrompt );
101+ end
102+
103+ function assistantParallelToolCallMessageIsAdded(testCase )
104+ msgs = openAIMessages ;
105+ functionName = " functionName" ;
106+ args = " {"" arg1"" : 1, "" arg2"" : 2, "" arg3"" : "" 3"" }" ;
107+ funCall = struct(" name" , functionName , " arguments" , args );
108+ toolCall = struct(" id" , " 123" , " type" , " function" , " function" , funCall );
109+ toolCallPrompt = struct(" role" , " assistant" , " content" , " " , " tool_calls" , []);
110+ toolCallPrompt.tool_calls = [toolCall ,toolCall ,toolCall ];
111+ msgs = addResponseMessage(msgs , toolCallPrompt );
112+ testCase .verifyEqual(msgs.Messages{1 }, toolCallPrompt );
100113 end
101114
102115 function messageGetsRemoved(testCase )
@@ -105,7 +118,7 @@ function messageGetsRemoved(testCase)
105118
106119 msgs = addSystemMessage(msgs , " name" , " content" );
107120 msgs = addUserMessage(msgs , " content" );
108- msgs = addFunctionMessage (msgs , " name" , " content" );
121+ msgs = addToolMessage (msgs , " 123 " , " name" , " content" );
109122 sizeMsgs = length(msgs .Messages );
110123 % Message exists before removal
111124 msgToBeRemoved = msgs.Messages{idx };
@@ -121,7 +134,7 @@ function removalIdxCantBeLargerThanNumElements(testCase)
121134
122135 msgs = addSystemMessage(msgs , " name" , " content" );
123136 msgs = addUserMessage(msgs , " content" );
124- msgs = addFunctionMessage (msgs , " name" , " content" );
137+ msgs = addToolMessage (msgs , " 123 " , " name" , " content" );
125138 sizeMsgs = length(msgs .Messages );
126139
127140 testCase .verifyError(@()removeMessage(msgs , sizeMsgs + 1 ), " llms:mustBeValidIndex" );
@@ -144,7 +157,7 @@ function invalidInputsUserImagesPrompt(testCase, InvalidInputsUserImagesPrompt)
144157
145158 function invalidInputsFunctionPrompt(testCase , InvalidInputsFunctionPrompt )
146159 msgs = openAIMessages ;
147- testCase .verifyError(@()addFunctionMessage (msgs ,InvalidInputsFunctionPrompt.Input{: }), InvalidInputsFunctionPrompt .Error );
160+ testCase .verifyError(@()addToolMessage (msgs ,InvalidInputsFunctionPrompt.Input{: }), InvalidInputsFunctionPrompt .Error );
148161 end
149162
150163 function invalidInputsRemove(testCase , InvalidRemoveMessage )
@@ -231,27 +244,27 @@ function invalidInputsResponsePrompt(testCase, InvalidInputsResponseMessage)
231244function invalidFunctionPrompt = iGetInvalidFunctionPrompt
232245 invalidFunctionPrompt = struct( ...
233246 " NonStringInputName" , ...
234- struct(" Input" , {{123 , " content" }}, ...
247+ struct(" Input" , {{" 123 " , 123 , " content" }}, ...
235248 " Error" , " MATLAB:validators:mustBeNonzeroLengthText" ), ...
236249 ...
237250 " NonStringInputContent" , ...
238- struct(" Input" , {{" name" , 123 }}, ...
251+ struct(" Input" , {{" 123 " , " name" , 123 }}, ...
239252 " Error" , " MATLAB:validators:mustBeNonzeroLengthText" ), ...
240253 ...
241254 " EmptytName" , ...
242- struct(" Input" , {{" " , " content" }}, ...
255+ struct(" Input" , {{" 123 " , " " , " content" }}, ...
243256 " Error" , " MATLAB:validators:mustBeNonzeroLengthText" ), ...
244257 ...
245258 " EmptytContent" , ...
246- struct(" Input" , {{" name" , " " }}, ...
259+ struct(" Input" , {{" 123 " , " name" , " " }}, ...
247260 " Error" , " MATLAB:validators:mustBeNonzeroLengthText" ), ...
248261 ...
249262 " NonScalarInputName" , ...
250- struct(" Input" , {{[" name1" " name2" ], " content" }}, ...
263+ struct(" Input" , {{" 123 " , [" name1" " name2" ], " content" }}, ...
251264 " Error" , " MATLAB:validators:mustBeTextScalar" ),...
252265 ...
253266 " NonScalarInputContent" , ...
254- struct(" Input" , {{" name" , [" content1" , " content2" ]}}, ...
267+ struct(" Input" , {{" 123 " , " name" , [" content1" , " content2" ]}}, ...
255268 " Error" , " MATLAB:validators:mustBeTextScalar" ));
256269end
257270
0 commit comments