Skip to content

Commit c776066

Browse files
committed
add: test cases
Signed-off-by: Jiaru Jiang <jiaruj@amazon.com>
1 parent e5ab87b commit c776066

File tree

3 files changed

+180
-2
lines changed

3 files changed

+180
-2
lines changed

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1331,4 +1331,41 @@ public void testGenerateLLMSummaryWithNullSteps() {
13311331

13321332
verify(listener).onFailure(any(IllegalArgumentException.class));
13331333
}
1334+
1335+
@Test
1336+
public void testExtractSummaryFromResponse_WithResponseField() {
1337+
Map<String, Object> dataMap = new HashMap<>();
1338+
dataMap.put("response", "Summary from response field");
1339+
ModelTensor tensor = ModelTensor.builder().dataAsMap(dataMap).build();
1340+
ModelTensors tensors = ModelTensors.builder().mlModelTensors(Arrays.asList(tensor)).build();
1341+
ModelTensorOutput output = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(tensors)).build();
1342+
MLTaskResponse response = MLTaskResponse.builder().output(output).build();
1343+
1344+
String result = mlChatAgentRunner.extractSummaryFromResponse(response);
1345+
assertEquals("Summary from response field", result);
1346+
}
1347+
1348+
@Test
1349+
public void testExtractSummaryFromResponse_WithNullDataMap() {
1350+
ModelTensor tensor = ModelTensor.builder().build();
1351+
ModelTensors tensors = ModelTensors.builder().mlModelTensors(Arrays.asList(tensor)).build();
1352+
ModelTensorOutput output = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(tensors)).build();
1353+
MLTaskResponse response = MLTaskResponse.builder().output(output).build();
1354+
1355+
String result = mlChatAgentRunner.extractSummaryFromResponse(response);
1356+
assertEquals(null, result);
1357+
}
1358+
1359+
@Test
1360+
public void testExtractSummaryFromResponse_WithEmptyDataMap() {
1361+
Map<String, Object> dataMap = new HashMap<>();
1362+
dataMap.put("other_field", "some value");
1363+
ModelTensor tensor = ModelTensor.builder().dataAsMap(dataMap).build();
1364+
ModelTensors tensors = ModelTensors.builder().mlModelTensors(Arrays.asList(tensor)).build();
1365+
ModelTensorOutput output = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(tensors)).build();
1366+
MLTaskResponse response = MLTaskResponse.builder().output(output).build();
1367+
1368+
String result = mlChatAgentRunner.extractSummaryFromResponse(response);
1369+
assertEquals(null, result);
1370+
}
13341371
}

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunnerTest.java

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1001,4 +1001,145 @@ public void testExecutionWithNullStepResult() {
10011001
// Verify that onFailure was called with the expected exception
10021002
verify(agentActionListener).onFailure(any(IllegalStateException.class));
10031003
}
1004+
1005+
@Test
1006+
public void testMaxStepsWithSingleCompletedStep() {
1007+
MLAgent mlAgent = createMLAgentWithTools();
1008+
1009+
doAnswer(invocation -> {
1010+
ActionListener<List<Interaction>> listener = invocation.getArgument(0);
1011+
listener.onResponse(Arrays.asList(Interaction.builder().id("i1").input("step1").response("").build()));
1012+
return null;
1013+
}).when(conversationIndexMemory).getMessages(any(), anyInt());
1014+
1015+
doAnswer(invocation -> {
1016+
ActionListener<UpdateResponse> listener = invocation.getArgument(2);
1017+
listener.onResponse(updateResponse);
1018+
return null;
1019+
}).when(mlMemoryManager).updateInteraction(any(), any(), any());
1020+
1021+
Map<String, String> params = new HashMap<>();
1022+
params.put("question", "test");
1023+
params.put("parent_interaction_id", "pid");
1024+
params.put("max_steps", "0");
1025+
mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener);
1026+
1027+
verify(agentActionListener).onResponse(objectCaptor.capture());
1028+
String response = (String) ((ModelTensorOutput) objectCaptor.getValue())
1029+
.getMlModelOutputs()
1030+
.get(1)
1031+
.getMlModelTensors()
1032+
.get(0)
1033+
.getDataAsMap()
1034+
.get("response");
1035+
assertTrue(response.contains("Max Steps Limit (0) Reached"));
1036+
}
1037+
1038+
@Test
1039+
public void testSummaryExtractionWithResultField() {
1040+
MLAgent mlAgent = createMLAgentWithTools();
1041+
1042+
doAnswer(invocation -> {
1043+
ActionListener<Object> listener = invocation.getArgument(2);
1044+
ModelTensor tensor = ModelTensor.builder().result("Summary from result").build();
1045+
when(mlTaskResponse.getOutput())
1046+
.thenReturn(
1047+
ModelTensorOutput
1048+
.builder()
1049+
.mlModelOutputs(Arrays.asList(ModelTensors.builder().mlModelTensors(Arrays.asList(tensor)).build()))
1050+
.build()
1051+
);
1052+
listener.onResponse(mlTaskResponse);
1053+
return null;
1054+
}).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(MLPredictionTaskRequest.class), any());
1055+
1056+
doAnswer(invocation -> {
1057+
ActionListener<UpdateResponse> listener = invocation.getArgument(2);
1058+
listener.onResponse(updateResponse);
1059+
return null;
1060+
}).when(mlMemoryManager).updateInteraction(any(), any(), any());
1061+
1062+
Map<String, String> params = new HashMap<>();
1063+
params.put("question", "test");
1064+
params.put("parent_interaction_id", "pid");
1065+
params.put("max_steps", "0");
1066+
mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener);
1067+
1068+
verify(agentActionListener).onResponse(objectCaptor.capture());
1069+
String response = (String) ((ModelTensorOutput) objectCaptor.getValue())
1070+
.getMlModelOutputs()
1071+
.get(1)
1072+
.getMlModelTensors()
1073+
.get(0)
1074+
.getDataAsMap()
1075+
.get("response");
1076+
assertTrue(response.contains("Summary from result"));
1077+
}
1078+
1079+
@Test
1080+
public void testSummaryExtractionWithEmptyResponse() {
1081+
MLAgent mlAgent = createMLAgentWithTools();
1082+
1083+
doAnswer(invocation -> {
1084+
ActionListener<Object> listener = invocation.getArgument(2);
1085+
ModelTensor tensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", " ")).build();
1086+
when(mlTaskResponse.getOutput())
1087+
.thenReturn(
1088+
ModelTensorOutput
1089+
.builder()
1090+
.mlModelOutputs(Arrays.asList(ModelTensors.builder().mlModelTensors(Arrays.asList(tensor)).build()))
1091+
.build()
1092+
);
1093+
listener.onResponse(mlTaskResponse);
1094+
return null;
1095+
}).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(MLPredictionTaskRequest.class), any());
1096+
1097+
doAnswer(invocation -> {
1098+
ActionListener<UpdateResponse> listener = invocation.getArgument(2);
1099+
listener.onResponse(updateResponse);
1100+
return null;
1101+
}).when(mlMemoryManager).updateInteraction(any(), any(), any());
1102+
1103+
Map<String, String> params = new HashMap<>();
1104+
params.put("question", "test");
1105+
params.put("parent_interaction_id", "pid");
1106+
params.put("max_steps", "0");
1107+
mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener);
1108+
1109+
verify(agentActionListener).onResponse(objectCaptor.capture());
1110+
String response = (String) ((ModelTensorOutput) objectCaptor.getValue())
1111+
.getMlModelOutputs()
1112+
.get(1)
1113+
.getMlModelTensors()
1114+
.get(0)
1115+
.getDataAsMap()
1116+
.get("response");
1117+
assertTrue(response.contains("Max Steps Limit"));
1118+
}
1119+
1120+
@Test
1121+
public void testSummaryExtractionWithNullOutput() {
1122+
MLAgent mlAgent = createMLAgentWithTools();
1123+
1124+
doAnswer(invocation -> {
1125+
ActionListener<Object> listener = invocation.getArgument(2);
1126+
when(mlTaskResponse.getOutput()).thenReturn(null);
1127+
listener.onResponse(mlTaskResponse);
1128+
return null;
1129+
}).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(MLPredictionTaskRequest.class), any());
1130+
1131+
doAnswer(invocation -> {
1132+
ActionListener<UpdateResponse> listener = invocation.getArgument(2);
1133+
listener.onResponse(updateResponse);
1134+
return null;
1135+
}).when(mlMemoryManager).updateInteraction(any(), any(), any());
1136+
1137+
Map<String, String> params = new HashMap<>();
1138+
params.put("question", "test");
1139+
params.put("parent_interaction_id", "pid");
1140+
params.put("max_steps", "0");
1141+
mlPlanExecuteAndReflectAgentRunner.run(mlAgent, params, agentActionListener);
1142+
1143+
verify(agentActionListener).onResponse(any());
1144+
}
10041145
}

plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,7 @@ public void test_Update_RSC_FeatureEnabled_TypeEnabled_SkipsLegacyValidation() t
464464
// Enable RSC fast-path.
465465
ResourceSharingClient rsc = mock(ResourceSharingClient.class);
466466
ResourceSharingClientAccessor.getInstance().setResourceSharingClient(rsc);
467-
// when(rsc.isFeatureEnabledForType(any())).thenReturn(true);
467+
when(rsc.isFeatureEnabledForType(any())).thenReturn(true);
468468

469469
// No ACL changes in request (so even legacy would pass, but we won't go there).
470470
MLUpdateModelGroupRequest req = prepareRequest(null, null, null);
@@ -486,7 +486,7 @@ public void test_Update_RSC_FeatureEnabled_TypeDisabled_UsesLegacyValidation() t
486486
// RSC feature on, but type disabled → legacy path.
487487
ResourceSharingClient rsc = mock(ResourceSharingClient.class);
488488
ResourceSharingClientAccessor.getInstance().setResourceSharingClient(rsc);
489-
// when(rsc.isFeatureEnabledForType(any())).thenReturn(false);
489+
when(rsc.isFeatureEnabledForType(any())).thenReturn(false);
490490

491491
// Allow legacy validation to pass:
492492
// security/model-access-control enabled:

0 commit comments

Comments
 (0)