Skip to content

Commit d6fbeef

Browse files
committed
Bug fixes.
1 parent da1c112 commit d6fbeef

File tree

5 files changed

+43
-32
lines changed

5 files changed

+43
-32
lines changed

+llms/+utils/errorMessageCatalog.m

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
catalog("llms:assistantMustHaveTextNameAndArguments") = "Fields 'name' and 'arguments' must be text with one or more characters.";
4646
catalog("llms:mustBeValidIndex") = "Value is larger than the number of elements in Messages ({1}).";
4747
catalog("llms:stopSequencesMustHaveMax4Elements") = "Number of elements must not be larger than 4.";
48-
catalog("llms:keyMustBeSpecified") = "API key not found as environment variable OPEN_API_KEY and not specified via ApiKey parameter.";
48+
catalog("llms:keyMustBeSpecified") = "API key not found as environment variable OPENAI_API_KEY and not specified via ApiKey parameter.";
4949
catalog("llms:mustHaveMessages") = "Value must contain at least one message in Messages.";
5050
catalog("llms:mustSetFunctionsForCall") = "When no functions are defined, FunctionCall must not be specified.";
5151
catalog("llms:mustBeMessagesOrTxt") = "Messages must be text with one or more characters or an openAIMessages objects.";

openAIChat.m

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,6 @@
5252
%
5353
% SystemPrompt - System prompt.
5454
%
55-
% AvailableModels - List of available models.
56-
%
5755
% FunctionNames - Names of the functions that the model can
5856
% request calls.
5957

@@ -93,25 +91,18 @@
9391
ApiKey
9492
end
9593

96-
properties(Constant)
97-
%AVAILABLEMODELS List of available models.
98-
AvailableModels = ["gpt-4", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613",...
99-
"gpt-3.5-turbo", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k",...
100-
"gpt-3.5-turbo-16k-0613"]
101-
end
102-
10394
methods
10495
function this = openAIChat(systemPrompt, nvp)
10596
arguments
10697
systemPrompt {llms.utils.mustBeTextOrEmpty} = []
10798
nvp.Functions (1,:) {mustBeA(nvp.Functions, "openAIFunction")} = openAIFunction.empty
108-
nvp.ModelName (1,1) {mustBeMember(nvp.ModelName,["gpt-4", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613",...
99+
nvp.ModelName (1,1) {mustBeMember(nvp.ModelName,["gpt-4", "gpt-4-0613", "gpt-4-32k", ...
109100
"gpt-3.5-turbo", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k",...
110101
"gpt-3.5-turbo-16k-0613"])} = "gpt-3.5-turbo"
111102
nvp.Temperature (1,1) {mustBeValidTemperature} = 1
112103
nvp.TopProbabilityMass (1,1) {mustBeValidTopP} = 1
113104
nvp.StopSequences (1,:) {mustBeValidStop} = {}
114-
nvp.ApiKey (1,1) {mustBeNonzeroLengthText}
105+
nvp.ApiKey {mustBeNonzeroLengthTextScalar}
115106
nvp.PresencePenalty (1,1) {mustBeValidPenalty} = 0
116107
nvp.FrequencyPenalty (1,1) {mustBeValidPenalty} = 0
117108
end
@@ -249,6 +240,10 @@ function mustBeValidFunctionCall(this, functionCall)
249240
end
250241
end
251242

243+
function mustBeNonzeroLengthTextScalar(content)
244+
mustBeNonzeroLengthText(content)
245+
mustBeTextScalar(content)
246+
end
252247

253248
function [functionsStruct, functionNames] = functionAsStruct(functions)
254249
numFunctions = numel(functions);
@@ -268,7 +263,7 @@ function mustBeValidMsgs(value)
268263
end
269264
else
270265
try
271-
mustBeNonzeroLengthText(value);
266+
mustBeNonzeroLengthTextScalar(value);
272267
catch ME
273268
error("llms:mustBeMessagesOrTxt", llms.utils.errorMessageCatalog.getMessage("llms:mustBeMessagesOrTxt"));
274269
end

openAIMessages.m

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,11 @@
3939

4040
arguments
4141
this (1,1) openAIMessages
42-
name (1,1) {mustBeNonzeroLengthText}
43-
content (1,1) {mustBeNonzeroLengthText}
42+
name {mustBeNonzeroLengthTextScalar}
43+
content {mustBeNonzeroLengthTextScalar}
4444
end
4545

46-
newMessage = struct("role", "system", "name", name, "content", content);
46+
newMessage = struct("role", "system", "name", string(name), "content", string(content));
4747
this.Messages{end+1} = newMessage;
4848
end
4949

@@ -62,10 +62,10 @@
6262

6363
arguments
6464
this (1,1) openAIMessages
65-
content (1,1) {mustBeNonzeroLengthText}
65+
content {mustBeNonzeroLengthTextScalar}
6666
end
6767

68-
newMessage = struct("role", "user", "content", content);
68+
newMessage = struct("role", "user", "content", string(content));
6969
this.Messages{end+1} = newMessage;
7070
end
7171

@@ -86,11 +86,11 @@
8686

8787
arguments
8888
this (1,1) openAIMessages
89-
name (1,1) {mustBeNonzeroLengthText}
90-
content (1,1) {mustBeNonzeroLengthText}
89+
name {mustBeNonzeroLengthTextScalar}
90+
content {mustBeNonzeroLengthTextScalar}
9191
end
9292

93-
newMessage = struct("role", "function", "name", name, "content", content);
93+
newMessage = struct("role", "function", "name", string(name), "content", string(content));
9494
this.Messages{end+1} = newMessage;
9595
end
9696

@@ -133,7 +133,7 @@
133133
if isfield(messageStruct, "function_call")
134134
funCall = messageStruct.function_call;
135135
validateAssistantWithFunctionCall(funCall)
136-
this = addAssistantMessage(this,funCall.name, funCall.arguments);
136+
this = addAssistantMessage(this, funCall.name, funCall.arguments);
137137
else
138138
% Simple assistant response
139139
validateRegularAssistant(messageStruct.content);
@@ -197,6 +197,11 @@
197197
end
198198
end
199199

200+
function mustBeNonzeroLengthTextScalar(content)
201+
mustBeNonzeroLengthText(content)
202+
mustBeTextScalar(content)
203+
end
204+
200205
function validateRegularAssistant(content)
201206
try
202207
mustBeNonzeroLengthText(content)

tests/topenAIChat.m

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ function saveEnvVar(testCase)
2626
function generateAcceptsSingleStringAsInput(testCase)
2727
chat = openAIChat(ApiKey="this-is-not-a-real-key");
2828
testCase.verifyWarningFree(@()generate(chat,"This is okay"));
29+
chat = openAIChat(ApiKey='this-is-not-a-real-key');
30+
testCase.verifyWarningFree(@()generate(chat,"This is okay"));
2931
end
3032

3133
function generateAcceptsMessagesAsInput(testCase)
@@ -307,7 +309,7 @@ function assignValueToProperty(property, value)
307309
...
308310
"InvalidApiKeySize",struct( ...
309311
"Input",{{ "ApiKey" ["abc" "abc"] }},...
310-
"Error","MATLAB:validation:IncompatibleSize"));
312+
"Error","MATLAB:validators:mustBeTextScalar"));
311313
end
312314

313315
function invalidGenerateInput = iGetInvalidGenerateInput
@@ -354,5 +356,4 @@ function assignValueToProperty(property, value)
354356
"InvalidFunctionCallSize",struct( ...
355357
"Input",{{ validMessages "FunctionCall" ["validfunction", "validfunction"] }},...
356358
"Error","MATLAB:validators:mustBeTextScalar"));
357-
end
358-
359+
end

tests/topenAIMessages.m

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
InvalidInputsSystemPrompt = iGetInvalidInputsSystemPrompt;
1010
InvalidInputsResponseMessage = iGetInvalidInputsResponseMessage;
1111
InvalidRemoveMessage = iGetInvalidRemoveMessage;
12+
ValidTextInput = {"This is okay"; 'this is ok'};
1213
end
1314

1415
methods(Test)
@@ -17,6 +18,15 @@ function constructorStartsWithEmptyMessages(testCase)
1718
testCase.verifyTrue(isempty(msgs.Messages));
1819
end
1920

21+
function differentInputTextAccepted(testCase, ValidTextInput)
22+
msgs = openAIMessages;
23+
testCase.verifyWarningFree(@()addSystemMessage(msgs, ValidTextInput, ValidTextInput));
24+
testCase.verifyWarningFree(@()addSystemMessage(msgs, ValidTextInput, ValidTextInput));
25+
testCase.verifyWarningFree(@()addUserMessage(msgs, ValidTextInput));
26+
testCase.verifyWarningFree(@()addFunctionMessage(msgs, ValidTextInput, ValidTextInput));
27+
end
28+
29+
2030
function systemMessageIsAdded(testCase)
2131
prompt = "Here is a system prompt";
2232
name = "example";
@@ -56,7 +66,7 @@ function assistantFunctionCallMessageIsAdded(testCase)
5666
msgs = openAIMessages;
5767
args = "{""arg1"": 1, ""arg2"": 2, ""arg3"": ""3""}";
5868
funCall = struct("name", functionName, "arguments", args);
59-
functionCallPrompt = struct("role", "assistant", "content", [], "function_call", funCall);
69+
functionCallPrompt = struct("role", "assistant", "content", "", "function_call", funCall);
6070
msgs = addResponseMessage(msgs, functionCallPrompt);
6171
testCase.verifyEqual(msgs.Messages{1}, functionCallPrompt);
6272
end
@@ -65,7 +75,7 @@ function assistantFunctionCallMessageWithoutArgsIsAdded(testCase)
6575
functionName = "functionName";
6676
msgs = openAIMessages;
6777
funCall = struct("name", functionName, "arguments", "{}");
68-
functionCallPrompt = struct("role", "assistant", "content", [], "function_call", funCall);
78+
functionCallPrompt = struct("role", "assistant", "content", "", "function_call", funCall);
6979
msgs = addResponseMessage(msgs, functionCallPrompt);
7080
testCase.verifyEqual(msgs.Messages{1}, functionCallPrompt);
7181
end
@@ -145,11 +155,11 @@ function invalidInputsResponsePrompt(testCase, InvalidInputsResponseMessage)
145155
...
146156
"NonScalarInputName", ...
147157
struct("Input", {{["name1" "name2"], "content"}}, ...
148-
"Error", "MATLAB:validation:IncompatibleSize"),...
158+
"Error", "MATLAB:validators:mustBeTextScalar"),...
149159
...
150160
"NonScalarInputContent", ...
151161
struct("Input", {{"name", ["content1", "content2"]}}, ...
152-
"Error", "MATLAB:validation:IncompatibleSize"));
162+
"Error", "MATLAB:validators:mustBeTextScalar"));
153163
end
154164

155165
function invalidInputsUserPrompt = iGetInvalidInputsUserPrompt
@@ -160,7 +170,7 @@ function invalidInputsResponsePrompt(testCase, InvalidInputsResponseMessage)
160170
...
161171
"NonScalarInput", ...
162172
struct("Input", {{["prompt1" "prompt2"]}}, ...
163-
"Error", "MATLAB:validation:IncompatibleSize"), ...
173+
"Error", "MATLAB:validators:mustBeTextScalar"), ...
164174
...
165175
"EmptyInput", ...
166176
struct("Input", {{""}}, ...
@@ -187,11 +197,11 @@ function invalidInputsResponsePrompt(testCase, InvalidInputsResponseMessage)
187197
...
188198
"NonScalarInputName", ...
189199
struct("Input", {{["name1" "name2"], "content"}}, ...
190-
"Error", "MATLAB:validation:IncompatibleSize"),...
200+
"Error", "MATLAB:validators:mustBeTextScalar"),...
191201
...
192202
"NonScalarInputContent", ...
193203
struct("Input", {{"name", ["content1", "content2"]}}, ...
194-
"Error", "MATLAB:validation:IncompatibleSize"));
204+
"Error", "MATLAB:validators:mustBeTextScalar"));
195205
end
196206

197207
function invalidRemoveMessage = iGetInvalidRemoveMessage

0 commit comments

Comments
 (0)