|
5 | 5 |
|
6 | 6 | package org.opensearch.ml.engine.algorithms.contextmanager; |
7 | 7 |
|
| 8 | +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_FILTER; |
| 9 | + |
8 | 10 | import java.util.ArrayList; |
9 | 11 | import java.util.HashMap; |
10 | 12 | import java.util.List; |
|
16 | 18 | import org.mockito.Mock; |
17 | 19 | import org.mockito.MockitoAnnotations; |
18 | 20 | import org.opensearch.ml.common.contextmanager.ContextManagerContext; |
| 21 | +import org.opensearch.ml.common.output.model.ModelTensor; |
| 22 | +import org.opensearch.ml.common.output.model.ModelTensorOutput; |
| 23 | +import org.opensearch.ml.common.output.model.ModelTensors; |
| 24 | +import org.opensearch.ml.common.transport.MLTaskResponse; |
19 | 25 | import org.opensearch.transport.client.Client; |
20 | 26 |
|
21 | 27 | /** |
@@ -161,6 +167,159 @@ public void testProcessSummarizationResult() { |
161 | 167 | Assert.assertTrue(firstOutput.contains("Test summary")); |
162 | 168 | } |
163 | 169 |
|
| 170 | + @Test |
| 171 | + public void testExtractSummaryFromResponseWithLLMResponseFilter() throws Exception { |
| 172 | + Map<String, Object> config = new HashMap<>(); |
| 173 | + manager.initialize(config); |
| 174 | + |
| 175 | + // Set up context with LLM_RESPONSE_FILTER |
| 176 | + Map<String, String> parameters = new HashMap<>(); |
| 177 | + parameters.put(LLM_RESPONSE_FILTER, "$.choices[0].message.content"); |
| 178 | + context.setParameters(parameters); |
| 179 | + |
| 180 | + // Create mock response with OpenAI-style structure |
| 181 | + Map<String, Object> responseData = new HashMap<>(); |
| 182 | + Map<String, Object> choice = new HashMap<>(); |
| 183 | + Map<String, Object> message = new HashMap<>(); |
| 184 | + message.put("content", "This is the extracted summary content"); |
| 185 | + choice.put("message", message); |
| 186 | + responseData.put("choices", List.of(choice)); |
| 187 | + |
| 188 | + MLTaskResponse mockResponse = createMockMLTaskResponse(responseData); |
| 189 | + |
| 190 | + // Use reflection to access the private method |
| 191 | + java.lang.reflect.Method extractMethod = SummarizationManager.class |
| 192 | + .getDeclaredMethod("extractSummaryFromResponse", MLTaskResponse.class, ContextManagerContext.class); |
| 193 | + extractMethod.setAccessible(true); |
| 194 | + |
| 195 | + String result = (String) extractMethod.invoke(manager, mockResponse, context); |
| 196 | + |
| 197 | + Assert.assertEquals("This is the extracted summary content", result); |
| 198 | + } |
| 199 | + |
| 200 | + @Test |
| 201 | + public void testExtractSummaryFromResponseWithBedrockResponseFilter() throws Exception { |
| 202 | + Map<String, Object> config = new HashMap<>(); |
| 203 | + manager.initialize(config); |
| 204 | + |
| 205 | + // Set up context with Bedrock-style LLM_RESPONSE_FILTER |
| 206 | + Map<String, String> parameters = new HashMap<>(); |
| 207 | + parameters.put(LLM_RESPONSE_FILTER, "$.output.message.content[0].text"); |
| 208 | + context.setParameters(parameters); |
| 209 | + |
| 210 | + // Create mock response with Bedrock-style structure |
| 211 | + Map<String, Object> responseData = new HashMap<>(); |
| 212 | + Map<String, Object> output = new HashMap<>(); |
| 213 | + Map<String, Object> message = new HashMap<>(); |
| 214 | + Map<String, Object> content = new HashMap<>(); |
| 215 | + content.put("text", "Bedrock extracted summary"); |
| 216 | + message.put("content", List.of(content)); |
| 217 | + output.put("message", message); |
| 218 | + responseData.put("output", output); |
| 219 | + |
| 220 | + MLTaskResponse mockResponse = createMockMLTaskResponse(responseData); |
| 221 | + |
| 222 | + // Use reflection to access the private method |
| 223 | + java.lang.reflect.Method extractMethod = SummarizationManager.class |
| 224 | + .getDeclaredMethod("extractSummaryFromResponse", MLTaskResponse.class, ContextManagerContext.class); |
| 225 | + extractMethod.setAccessible(true); |
| 226 | + |
| 227 | + String result = (String) extractMethod.invoke(manager, mockResponse, context); |
| 228 | + |
| 229 | + Assert.assertEquals("Bedrock extracted summary", result); |
| 230 | + } |
| 231 | + |
| 232 | + @Test |
| 233 | + public void testExtractSummaryFromResponseWithInvalidFilter() throws Exception { |
| 234 | + Map<String, Object> config = new HashMap<>(); |
| 235 | + manager.initialize(config); |
| 236 | + |
| 237 | + // Set up context with invalid LLM_RESPONSE_FILTER path |
| 238 | + Map<String, String> parameters = new HashMap<>(); |
| 239 | + parameters.put(LLM_RESPONSE_FILTER, "$.invalid.path"); |
| 240 | + context.setParameters(parameters); |
| 241 | + |
| 242 | + // Create mock response with simple structure |
| 243 | + Map<String, Object> responseData = new HashMap<>(); |
| 244 | + responseData.put("response", "Fallback summary content"); |
| 245 | + |
| 246 | + MLTaskResponse mockResponse = createMockMLTaskResponse(responseData); |
| 247 | + |
| 248 | + // Use reflection to access the private method |
| 249 | + java.lang.reflect.Method extractMethod = SummarizationManager.class |
| 250 | + .getDeclaredMethod("extractSummaryFromResponse", MLTaskResponse.class, ContextManagerContext.class); |
| 251 | + extractMethod.setAccessible(true); |
| 252 | + |
| 253 | + String result = (String) extractMethod.invoke(manager, mockResponse, context); |
| 254 | + |
| 255 | + // Should fall back to default parsing |
| 256 | + Assert.assertEquals("Fallback summary content", result); |
| 257 | + } |
| 258 | + |
| 259 | + @Test |
| 260 | + public void testExtractSummaryFromResponseWithoutFilter() throws Exception { |
| 261 | + Map<String, Object> config = new HashMap<>(); |
| 262 | + manager.initialize(config); |
| 263 | + |
| 264 | + // Context without LLM_RESPONSE_FILTER |
| 265 | + Map<String, String> parameters = new HashMap<>(); |
| 266 | + context.setParameters(parameters); |
| 267 | + |
| 268 | + // Create mock response with simple structure |
| 269 | + Map<String, Object> responseData = new HashMap<>(); |
| 270 | + responseData.put("response", "Default parsed summary"); |
| 271 | + |
| 272 | + MLTaskResponse mockResponse = createMockMLTaskResponse(responseData); |
| 273 | + |
| 274 | + // Use reflection to access the private method |
| 275 | + java.lang.reflect.Method extractMethod = SummarizationManager.class |
| 276 | + .getDeclaredMethod("extractSummaryFromResponse", MLTaskResponse.class, ContextManagerContext.class); |
| 277 | + extractMethod.setAccessible(true); |
| 278 | + |
| 279 | + String result = (String) extractMethod.invoke(manager, mockResponse, context); |
| 280 | + |
| 281 | + Assert.assertEquals("Default parsed summary", result); |
| 282 | + } |
| 283 | + |
| 284 | + @Test |
| 285 | + public void testExtractSummaryFromResponseWithEmptyFilter() throws Exception { |
| 286 | + Map<String, Object> config = new HashMap<>(); |
| 287 | + manager.initialize(config); |
| 288 | + |
| 289 | + // Set up context with empty LLM_RESPONSE_FILTER |
| 290 | + Map<String, String> parameters = new HashMap<>(); |
| 291 | + parameters.put(LLM_RESPONSE_FILTER, ""); |
| 292 | + context.setParameters(parameters); |
| 293 | + |
| 294 | + // Create mock response |
| 295 | + Map<String, Object> responseData = new HashMap<>(); |
| 296 | + responseData.put("response", "Empty filter fallback"); |
| 297 | + |
| 298 | + MLTaskResponse mockResponse = createMockMLTaskResponse(responseData); |
| 299 | + |
| 300 | + // Use reflection to access the private method |
| 301 | + java.lang.reflect.Method extractMethod = SummarizationManager.class |
| 302 | + .getDeclaredMethod("extractSummaryFromResponse", MLTaskResponse.class, ContextManagerContext.class); |
| 303 | + extractMethod.setAccessible(true); |
| 304 | + |
| 305 | + String result = (String) extractMethod.invoke(manager, mockResponse, context); |
| 306 | + |
| 307 | + Assert.assertEquals("Empty filter fallback", result); |
| 308 | + } |
| 309 | + |
| 310 | + /** |
| 311 | + * Helper method to create a mock MLTaskResponse with the given data. |
| 312 | + */ |
| 313 | + private MLTaskResponse createMockMLTaskResponse(Map<String, Object> responseData) { |
| 314 | + ModelTensor tensor = ModelTensor.builder().dataAsMap(responseData).build(); |
| 315 | + |
| 316 | + ModelTensors tensors = ModelTensors.builder().mlModelTensors(List.of(tensor)).build(); |
| 317 | + |
| 318 | + ModelTensorOutput output = ModelTensorOutput.builder().mlModelOutputs(List.of(tensors)).build(); |
| 319 | + |
| 320 | + return MLTaskResponse.builder().output(output).build(); |
| 321 | + } |
| 322 | + |
164 | 323 | /** |
165 | 324 | * Helper method to add tool interactions to the context. |
166 | 325 | */ |
|
0 commit comments