Skip to content

Commit f538803

Browse files
committed
add boolean, number, and enum argument validation to functions
1 parent c019530 commit f538803

File tree

3 files changed

+192
-1
lines changed

3 files changed

+192
-1
lines changed

src/main/kotlin/com/cjcrafter/openai/chat/tool/FunctionCall.kt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,14 @@ data class FunctionCall(
7676
when (property.type) {
7777
"integer" -> if (!value.isInt)
7878
throw HallucinationException("Expected an integer for argument $key")
79-
"enum" -> if (!value.isTextual)
79+
"number" -> if (!value.isDouble && !value.isInt)
80+
throw HallucinationException("Expected a number for argument $key")
81+
"boolean" -> if (!value.isBoolean)
82+
throw HallucinationException("Expected a boolean for argument $key")
83+
"string" -> if (!value.isTextual)
8084
throw HallucinationException("Expected a string for argument $key")
85+
"enum" -> if (!value.isTextual || !property.enum!!.contains(value.asText()))
86+
throw HallucinationException("Expected one of ${property.enum}, got $key")
8187
}
8288
} ?: throw HallucinationException("Unknown argument: $key")
8389
}

src/main/kotlin/com/cjcrafter/openai/chat/tool/FunctionTool.kt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,13 @@ data class FunctionTool internal constructor(
117117
parameters!!.require(name)
118118
}
119119

120+
/**
121+
* Makes this function take no parameters.
122+
*/
123+
fun noParameters() = apply {
124+
parameters = FunctionParameters()
125+
}
126+
120127
fun build() = FunctionTool(
121128
name = name ?: throw IllegalStateException("Name must be set"),
122129
parameters = parameters ?: throw IllegalStateException("Parameters must be set"),
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
package com.cjcrafter.openai.chat.tool
2+
3+
import com.cjcrafter.openai.OpenAI
4+
import com.cjcrafter.openai.exception.HallucinationException
5+
import com.fasterxml.jackson.module.kotlin.readValue
6+
import org.intellij.lang.annotations.Language
7+
import org.junit.jupiter.api.Assertions
8+
import org.junit.jupiter.api.Assertions.*
9+
import org.junit.jupiter.api.Test
10+
import org.junit.jupiter.api.assertDoesNotThrow
11+
import org.junit.jupiter.api.assertThrows
12+
13+
class FunctionCallTest {
14+
15+
@Test
16+
fun `test bad enum`() {
17+
val tools = listOf(
18+
functionTool {
19+
name("enum_checker")
20+
description("This function is used to test the enum parameter")
21+
addEnumParameter("enum", mutableListOf("a", "b", "c"))
22+
}.toTool()
23+
)
24+
@Language("json")
25+
val json = "{\"name\": \"enum_checker\", \"arguments\": \"{\\\"enum\\\": \\\"d\\\"}\"}" // d is not a valid enum
26+
val call = FunctionCall("enum_checker", json)
27+
28+
assertThrows<HallucinationException> {
29+
call.tryParseArguments(tools)
30+
}
31+
}
32+
33+
@Test
34+
fun `test good enum`() {
35+
val tools = listOf(
36+
functionTool {
37+
name("enum_checker")
38+
description("This function is used to test the enum parameter")
39+
addEnumParameter("enum", mutableListOf("a", "b", "c"))
40+
}.toTool()
41+
)
42+
@Language("json")
43+
val json = "{\"name\": \"enum_checker\", \"arguments\": \"{\\\"enum\\\": \\\"a\\\"}\"}" // a is a valid enum
44+
val call = OpenAI.createObjectMapper().readValue<FunctionCall>(json)
45+
46+
val args = call.tryParseArguments(tools)
47+
assertTrue(args.contains("enum")) { "enum should be present in the arguments" }
48+
assertEquals("a", args["enum"]?.asText())
49+
}
50+
51+
@Test
52+
fun `test bad integer`() {
53+
val tools = listOf(
54+
functionTool {
55+
name("integer_checker")
56+
description("This function is used to test the integer parameter")
57+
addIntegerParameter("integer", "test parameter")
58+
}.toTool()
59+
)
60+
@Language("json")
61+
val json = "{\"name\": \"integer_checker\", \"arguments\": \"{\\\"integer\\\": \\\"not an integer\\\"}\"}" // not an integer
62+
val call = OpenAI.createObjectMapper().readValue<FunctionCall>(json)
63+
64+
assertThrows<HallucinationException> {
65+
call.tryParseArguments(tools)
66+
}
67+
}
68+
69+
@Test
70+
fun `test good integer`() {
71+
val tools = listOf(
72+
functionTool {
73+
name("integer_checker")
74+
description("This function is used to test the integer parameter")
75+
addIntegerParameter("integer", "test parameter")
76+
}.toTool()
77+
)
78+
@Language("json")
79+
val json = "{\"name\": \"integer_checker\", \"arguments\": \"{\\\"integer\\\": 1}\"}" // 1 is an integer
80+
val call = OpenAI.createObjectMapper().readValue<FunctionCall>(json)
81+
82+
val args = call.tryParseArguments(tools)
83+
assertTrue(args.contains("integer")) { "integer should be present in the arguments" }
84+
assertEquals(1, args["integer"]?.asInt())
85+
}
86+
87+
@Test
88+
fun `test bad boolean`() {
89+
val tools = listOf(
90+
functionTool {
91+
name("boolean_checker")
92+
description("This function is used to test the boolean parameter")
93+
addBooleanParameter("is_true", "test parameter")
94+
}.toTool()
95+
)
96+
@Language("json")
97+
val json = "{\"name\": \"boolean_checker\", \"arguments\": \"{\\\"boolean\\\": \\\"not a boolean\\\"}\"}" // not a boolean
98+
val call = OpenAI.createObjectMapper().readValue<FunctionCall>(json)
99+
100+
assertThrows<HallucinationException> {
101+
call.tryParseArguments(tools)
102+
}
103+
}
104+
105+
@Test
106+
fun `test good boolean`() {
107+
val tools = listOf(
108+
functionTool {
109+
name("boolean_checker")
110+
description("This function is used to test the boolean parameter")
111+
addBooleanParameter("is_true", "test parameter")
112+
}.toTool()
113+
)
114+
@Language("json")
115+
val json = "{\"name\": \"boolean_checker\", \"arguments\": \"{\\\"is_true\\\": true}\"}" // true is a boolean
116+
val call = OpenAI.createObjectMapper().readValue<FunctionCall>(json)
117+
118+
val args = call.tryParseArguments(tools)
119+
assertTrue(args.contains("is_true")) { "is_true should be present in the arguments" }
120+
assertTrue(args["is_true"]?.asBoolean() ?: false)
121+
}
122+
123+
@Test
124+
fun `test missing required parameters`() {
125+
val tools = listOf(
126+
functionTool {
127+
name("required_parameter_function")
128+
description("This function is used to test the required parameter")
129+
addIntegerParameter("required", "test parameter", required = true)
130+
addBooleanParameter("not_required", "test parameter")
131+
}.toTool()
132+
)
133+
@Language("json")
134+
val json = "{\"name\": \"required_parameter_function\", \"arguments\": \"{\\\"not_required\\\": true}\"}" // missing required parameter
135+
val call = OpenAI.createObjectMapper().readValue<FunctionCall>(json)
136+
137+
assertThrows<HallucinationException> {
138+
call.tryParseArguments(tools)
139+
}
140+
}
141+
142+
@Test
143+
fun `test has required parameter`() {
144+
val tools = listOf(
145+
functionTool {
146+
name("required_parameter_function")
147+
description("This function is used to test the required parameter")
148+
addIntegerParameter("required", "test parameter", required = true)
149+
addBooleanParameter("not_required", "test parameter")
150+
}.toTool()
151+
)
152+
@Language("json")
153+
val json = "{\"name\": \"required_parameter_function\", \"arguments\": \"{\\\"required\\\": 1, \\\"not_required\\\": true}\"}" // has required parameter
154+
val call = OpenAI.createObjectMapper().readValue<FunctionCall>(json)
155+
156+
assertDoesNotThrow("Should not throw an exception when the required parameter is present") {
157+
call.tryParseArguments(tools)
158+
}
159+
}
160+
161+
@Test
162+
fun `test invalid function name`() {
163+
val tools = listOf(
164+
functionTool {
165+
name("function_name_checker")
166+
description("This function is used to test the function name")
167+
noParameters()
168+
}.toTool()
169+
)
170+
@Language("json")
171+
val json = "{\"name\": \"invalid_function_name\", \"arguments\": \"{}\"}" // invalid function name
172+
val call = OpenAI.createObjectMapper().readValue<FunctionCall>(json)
173+
174+
assertThrows<HallucinationException> {
175+
call.tryParseArguments(tools)
176+
}
177+
}
178+
}

0 commit comments

Comments
 (0)