Skip to content

Commit d416d95

Browse files
Add tool param type validation
Signed-off-by: Nathalie Jonathan <nathhjo@amazon.com>
1 parent 5964268 commit d416d95

File tree

17 files changed

+358
-10
lines changed

17 files changed

+358
-10
lines changed

common/src/main/java/org/opensearch/ml/common/input/execute/tool/ToolMLInput.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ public class ToolMLInput extends MLInput {
3333
@Setter
3434
private String toolName;
3535

36+
@Getter
37+
private Map<String, Object> originalParameters;
38+
3639
public ToolMLInput(StreamInput in) throws IOException {
3740
super(in);
3841
this.toolName = in.readString();
@@ -66,7 +69,9 @@ public ToolMLInput(XContentParser parser, FunctionName functionName) throws IOEx
6669
toolName = parser.text();
6770
break;
6871
case PARAMETERS_FIELD:
69-
Map<String, String> parameters = StringUtils.getParameterMap(parser.map());
72+
Map<String, Object> rawParams = parser.map();
73+
originalParameters = rawParams;
74+
Map<String, String> parameters = StringUtils.getParameterMap(rawParams);
7075
inputDataset = new RemoteInferenceInputDataSet(parameters);
7176
break;
7277
default:

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/tool/MLToolExecutor.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,14 @@ public void execute(Input input, ActionListener<Output> listener) {
8585
try {
8686
Map<String, String> mutableParams = new HashMap<>(parameters);
8787
Tool tool = toolFactory.create(mutableParams);
88+
89+
// Validate original parameter types
90+
Map<String, Object> originalParameters = toolMLInput.getOriginalParameters();
91+
if (originalParameters != null && !tool.validateParameterTypes(originalParameters)) {
92+
listener.onFailure(new IllegalArgumentException("Invalid parameter types for tool: " + toolName));
93+
return;
94+
}
95+
8896
if (!tool.validate(mutableParams)) {
8997
listener.onFailure(new IllegalArgumentException("Invalid parameters for tool: " + toolName));
9098
return;

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,28 @@ public void setName(String s) {
117117

118118
@Override
119119
public boolean validate(Map<String, String> parameters) {
120+
if (parameters == null || parameters.isEmpty()) {
121+
return false;
122+
}
123+
124+
// Validate question length
125+
String question = parameters.get("question");
126+
if (question != null && question.length() > 10000) {
127+
throw new IllegalArgumentException("question length cannot exceed 10000 characters");
128+
}
129+
130+
return true;
131+
}
132+
133+
@Override
134+
public boolean validateParameterTypes(Map<String, Object> parameters) {
135+
// Validate question must be String
136+
Object questionObj = parameters.get("question");
137+
if (questionObj != null && !(questionObj instanceof String)) {
138+
throw new IllegalArgumentException(
139+
String.format("question must be a String type, but got %s", questionObj.getClass().getSimpleName())
140+
);
141+
}
120142
return true;
121143
}
122144

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ConnectorTool.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,18 @@ public boolean validate(Map<String, String> parameters) {
100100
return parameters != null && !parameters.isEmpty();
101101
}
102102

103+
@Override
104+
public boolean validateParameterTypes(Map<String, Object> parameters) {
105+
// Validate response_filter must be String
106+
Object responseFilterObj = parameters.get("response_filter");
107+
if (responseFilterObj != null && !(responseFilterObj instanceof String)) {
108+
throw new IllegalArgumentException(
109+
String.format("response_filter must be a String type, but got %s", responseFilterObj.getClass().getSimpleName())
110+
);
111+
}
112+
return true;
113+
}
114+
103115
public static class Factory implements Tool.Factory<ConnectorTool> {
104116
public static final String TYPE = "ConnectorTool";
105117
public static final String DEFAULT_DESCRIPTION = "Invokes external service. Required: 'connector_id'. Returns: service response.";

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/IndexInsightTool.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,11 @@ public boolean validate(Map<String, String> parameters) {
113113
return true;
114114
}
115115

116+
@Override
117+
public boolean validateParameterTypes(Map<String, Object> parameters) {
118+
return true;
119+
}
120+
116121
public static class Factory implements Tool.Factory<IndexInsightTool> {
117122
private Client client;
118123

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/IndexMappingTool.java

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,45 @@ public String getType() {
175175

176176
@Override
177177
public boolean validate(Map<String, String> parameters) {
178-
return parameters != null && !parameters.isEmpty() && parameters.containsKey("index");
178+
if (parameters == null || parameters.isEmpty() || !parameters.containsKey("index")) {
179+
return false;
180+
}
181+
182+
// Validate question length
183+
String question = parameters.get("question");
184+
if (question != null && question.length() > 10000) {
185+
throw new IllegalArgumentException("question length cannot exceed 10000 characters");
186+
}
187+
188+
return true;
189+
}
190+
191+
@Override
192+
public boolean validateParameterTypes(Map<String, Object> parameters) {
193+
// Validate question must be String
194+
Object questionObj = parameters.get("question");
195+
if (questionObj != null && !(questionObj instanceof String)) {
196+
throw new IllegalArgumentException(
197+
String.format("question must be a String type, but got %s", questionObj.getClass().getSimpleName())
198+
);
199+
}
200+
201+
// Validate index must be ArrayList
202+
Object indexObj = parameters.get("index");
203+
if (indexObj != null && !(indexObj instanceof ArrayList)) {
204+
throw new IllegalArgumentException(
205+
String.format("index must be an Array type, but got %s", indexObj.getClass().getSimpleName())
206+
);
207+
}
208+
209+
// Validate local must be Boolean
210+
Object localObj = parameters.get("local");
211+
if (localObj != null && !(localObj instanceof Boolean)) {
212+
throw new IllegalArgumentException(
213+
String.format("local must be a Boolean type, but got %s", localObj.getClass().getSimpleName())
214+
);
215+
}
216+
return true;
179217
}
180218

181219
/**

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ListIndexTool.java

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,53 @@ public void onFailure(final Exception e) {
415415

416416
@Override
417417
public boolean validate(Map<String, String> parameters) {
418-
return parameters != null && !parameters.isEmpty();
418+
if (parameters == null || parameters.isEmpty()) {
419+
return false;
420+
}
421+
422+
// Validate question length
423+
String question = parameters.get("question");
424+
if (question != null && question.length() > 10000) {
425+
throw new IllegalArgumentException("question length cannot exceed 10000 characters");
426+
}
427+
428+
return true;
429+
}
430+
431+
@Override
432+
public boolean validateParameterTypes(Map<String, Object> parameters) {
433+
// Validate question must be String
434+
Object questionObj = parameters.get("question");
435+
if (questionObj != null && !(questionObj instanceof String)) {
436+
throw new IllegalArgumentException(
437+
String.format("question must be a String type, but got %s", questionObj.getClass().getSimpleName())
438+
);
439+
}
440+
441+
// Validate indices must be ArrayList
442+
Object indicesObj = parameters.get("indices");
443+
if (indicesObj != null && !(indicesObj instanceof ArrayList)) {
444+
throw new IllegalArgumentException(
445+
String.format("indices must be an Array type, but got %s", indicesObj.getClass().getSimpleName())
446+
);
447+
}
448+
449+
// Validate local must be Boolean
450+
Object localObj = parameters.get("local");
451+
if (localObj != null && !(localObj instanceof Boolean)) {
452+
throw new IllegalArgumentException(
453+
String.format("local must be a Boolean type, but got %s", localObj.getClass().getSimpleName())
454+
);
455+
}
456+
457+
// Validate page_size must be Integer
458+
Object pageSizeObj = parameters.get("page_size");
459+
if (pageSizeObj != null && !(pageSizeObj instanceof Integer)) {
460+
throw new IllegalArgumentException(
461+
String.format("page_size must be an Integer type, but got %s", pageSizeObj.getClass().getSimpleName())
462+
);
463+
}
464+
return true;
419465
}
420466

421467
/**

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,26 @@ public boolean validate(Map<String, String> parameters) {
148148
return parameters != null && !parameters.isEmpty();
149149
}
150150

151+
@Override
152+
public boolean validateParameterTypes(Map<String, Object> parameters) {
153+
// Validate prompt must be String
154+
Object promptObj = parameters.get("prompt");
155+
if (promptObj != null && !(promptObj instanceof String)) {
156+
throw new IllegalArgumentException(
157+
String.format("prompt must be a String type, but got %s", promptObj.getClass().getSimpleName())
158+
);
159+
}
160+
161+
// Validate response_field must be String
162+
Object responseFieldObj = parameters.get(RESPONSE_FIELD);
163+
if (responseFieldObj != null && !(responseFieldObj instanceof String)) {
164+
throw new IllegalArgumentException(
165+
String.format("%s must be a String type, but got %s", RESPONSE_FIELD, responseFieldObj.getClass().getSimpleName())
166+
);
167+
}
168+
return true;
169+
}
170+
151171
public static class Factory implements WithModelTool.Factory<MLModelTool> {
152172
private Client client;
153173

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/McpSseTool.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,11 @@ public boolean validate(Map<String, String> parameters) {
9999
return true;
100100
}
101101

102+
@Override
103+
public boolean validateParameterTypes(Map<String, Object> parameters) {
104+
return true;
105+
}
106+
102107
public static class Factory implements WithModelTool.Factory<McpSseTool> {
103108
private static Factory INSTANCE;
104109

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/McpStreamableHttpTool.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,11 @@ public boolean validate(Map<String, String> parameters) {
100100
return true;
101101
}
102102

103+
@Override
104+
public boolean validateParameterTypes(Map<String, Object> parameters) {
105+
return true;
106+
}
107+
103108
public static class Factory implements WithModelTool.Factory<McpStreamableHttpTool> {
104109
private static Factory INSTANCE;
105110

0 commit comments

Comments
 (0)