Skip to content

Commit 846a9ba

Browse files
committed
add code coverage
Signed-off-by: Mingshi Liu <mingshl@amazon.com>
1 parent fc0d896 commit 846a9ba

File tree

9 files changed

+434
-10
lines changed

9 files changed

+434
-10
lines changed

common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagementTemplate.java

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -237,15 +237,20 @@ public boolean isValid() {
237237
return false;
238238
}
239239

240-
if (hooks == null || hooks.isEmpty()) {
241-
return false;
242-
}
240+
// Allow null hooks (no context management) but not empty hooks map (misconfiguration)
241+
if (hooks != null) {
242+
if (hooks.isEmpty()) {
243+
return false;
244+
}
243245

244-
// Validate all context manager configs
245-
for (List<ContextManagerConfig> configs : hooks.values()) {
246-
for (ContextManagerConfig config : configs) {
247-
if (!config.isValid()) {
248-
return false;
246+
// Validate all context manager configs
247+
for (List<ContextManagerConfig> configs : hooks.values()) {
248+
if (configs != null) {
249+
for (ContextManagerConfig config : configs) {
250+
if (!config.isValid()) {
251+
return false;
252+
}
253+
}
249254
}
250255
}
251256
}

common/src/main/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequest.java

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,57 @@ public ActionRequestValidationException validate() {
5353
if (mlAgent.getContextManagementName() != null && mlAgent.getContextManagement() != null) {
5454
exception = addValidationError("Cannot specify both context_management_name and context_management", exception);
5555
}
56+
57+
// Validate context management template name
58+
if (mlAgent.getContextManagementName() != null) {
59+
exception = validateContextManagementTemplateName(mlAgent.getContextManagementName(), exception);
60+
}
61+
62+
// Validate inline context management configuration
63+
if (mlAgent.getContextManagement() != null) {
64+
exception = validateInlineContextManagement(mlAgent.getContextManagement(), exception);
65+
}
5666
}
5767

5868
return exception;
5969
}
6070

71+
private ActionRequestValidationException validateContextManagementTemplateName(
72+
String templateName,
73+
ActionRequestValidationException exception
74+
) {
75+
if (templateName == null || templateName.trim().isEmpty()) {
76+
exception = addValidationError("Context management template name cannot be null or empty", exception);
77+
} else if (templateName.length() > 256) {
78+
exception = addValidationError("Context management template name cannot exceed 256 characters", exception);
79+
} else if (!templateName.matches("^[a-zA-Z0-9._-]+$")) {
80+
exception = addValidationError(
81+
"Context management template name can only contain letters, numbers, underscores, hyphens, and dots",
82+
exception
83+
);
84+
}
85+
return exception;
86+
}
87+
88+
private ActionRequestValidationException validateInlineContextManagement(
89+
org.opensearch.ml.common.contextmanager.ContextManagementTemplate contextManagement,
90+
ActionRequestValidationException exception
91+
) {
92+
if (contextManagement.getHooks() != null) {
93+
for (String hookName : contextManagement.getHooks().keySet()) {
94+
if (!isValidHookName(hookName)) {
95+
exception = addValidationError("Invalid hook name: " + hookName, exception);
96+
}
97+
}
98+
}
99+
return exception;
100+
}
101+
102+
private boolean isValidHookName(String hookName) {
103+
// Define valid hook names based on the system's supported hooks
104+
return hookName.equals("POST_TOOL") || hookName.equals("PRE_LLM") || hookName.equals("PRE_TOOL") || hookName.equals("POST_LLM");
105+
}
106+
61107
@Override
62108
public void writeTo(StreamOutput out) throws IOException {
63109
super.writeTo(out);

common/src/test/java/org/opensearch/ml/common/transport/agent/MLRegisterAgentRequestTest.java

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,75 @@ public void validate_NoContextManagement_Valid() {
294294
assertNull(exception);
295295
}
296296

297+
@Test
298+
public void validate_ContextManagementTemplateName_NullValue() {
299+
// Test null template name - this should pass validation since null is acceptable
300+
MLAgent agentWithNullName = MLAgent.builder().name("test_agent").type("flow").contextManagementName(null).build();
301+
302+
MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithNullName);
303+
ActionRequestValidationException exception = request.validate();
304+
305+
assertNull(exception);
306+
}
307+
308+
@Test
309+
public void validate_ContextManagementTemplateName_Null() {
310+
// Test null template name validation
311+
MLAgent agentWithNullName = MLAgent.builder().name("test_agent").type("flow").contextManagementName(null).build();
312+
313+
MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithNullName);
314+
ActionRequestValidationException exception = request.validate();
315+
316+
// This should pass since null is handled differently than empty
317+
assertNull(exception);
318+
}
319+
320+
@Test
321+
public void validate_InlineContextManagement_NullHooks() {
322+
// Test inline context management with null hooks
323+
ContextManagementTemplate contextManagementWithNullHooks = ContextManagementTemplate
324+
.builder()
325+
.name("test_template")
326+
.hooks(null)
327+
.build();
328+
329+
MLAgent agentWithNullHooks = MLAgent
330+
.builder()
331+
.name("test_agent")
332+
.type("flow")
333+
.contextManagement(contextManagementWithNullHooks)
334+
.build();
335+
336+
MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithNullHooks);
337+
ActionRequestValidationException exception = request.validate();
338+
339+
// Should pass since null hooks are handled gracefully
340+
assertNull(exception);
341+
}
342+
343+
@Test
344+
public void validate_HookName_AllValidTypes() {
345+
// Test all valid hook names to improve branch coverage
346+
Map<String, List<ContextManagerConfig>> allValidHooks = new HashMap<>();
347+
allValidHooks.put("POST_TOOL", Arrays.asList(new ContextManagerConfig("ToolsOutputTruncateManager", null, null)));
348+
allValidHooks.put("PRE_LLM", Arrays.asList(new ContextManagerConfig("SummarizationManager", null, null)));
349+
allValidHooks.put("PRE_TOOL", Arrays.asList(new ContextManagerConfig("MemoryManager", null, null)));
350+
allValidHooks.put("POST_LLM", Arrays.asList(new ContextManagerConfig("ConversationManager", null, null)));
351+
352+
ContextManagementTemplate contextManagement = ContextManagementTemplate
353+
.builder()
354+
.name("test_template")
355+
.hooks(allValidHooks)
356+
.build();
357+
358+
MLAgent agentWithAllHooks = MLAgent.builder().name("test_agent").type("flow").contextManagement(contextManagement).build();
359+
360+
MLRegisterAgentRequest request = new MLRegisterAgentRequest(agentWithAllHooks);
361+
ActionRequestValidationException exception = request.validate();
362+
363+
assertNull(exception);
364+
}
365+
297366
/**
298367
* Helper method to create valid hooks configuration for testing
299368
*/

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -863,6 +863,8 @@ public void test_mcp_connector_requires_mcp_connector_enabled() throws IOExcepti
863863
Instant.EPOCH,
864864
"test",
865865
false,
866+
null,
867+
null,
866868
null
867869
);
868870

@@ -946,6 +948,8 @@ public void test_query_planning_agentic_search_enabled() throws IOException {
946948
Instant.EPOCH,
947949
"test",
948950
false,
951+
null,
952+
null,
949953
null
950954
);
951955

@@ -1045,6 +1049,8 @@ public GetResponse prepareMLAgent(String agentId, boolean isHidden, String tenan
10451049
Instant.EPOCH,
10461050
"test",
10471051
isHidden,
1052+
null,
1053+
null,
10481054
tenantId
10491055
);
10501056

plugin/src/main/java/org/opensearch/ml/action/agent/MLAgentRegistrationValidator.java

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,45 @@ public MLAgentRegistrationValidator(ContextManagementTemplateService contextMana
2828
this.contextManagementTemplateService = contextManagementTemplateService;
2929
}
3030

31+
/**
32+
* Validates an ML agent for registration, performing all necessary validation checks.
33+
* This is the main validation entry point that orchestrates all validation steps.
34+
*
35+
* @param agent the ML agent to validate
36+
* @param listener callback for validation result - onResponse(true) if valid, onFailure with exception if not
37+
*/
38+
public void validateAgentForRegistration(MLAgent agent, ActionListener<Boolean> listener) {
39+
try {
40+
log.debug("Starting agent registration validation for agent: {}", agent.getName());
41+
42+
// First, perform basic context management configuration validation
43+
String configError = validateContextManagementConfiguration(agent);
44+
if (configError != null) {
45+
log.error("Agent registration validation failed - configuration error: {}", configError);
46+
listener.onFailure(new IllegalArgumentException(configError));
47+
return;
48+
}
49+
50+
// If agent has a context management template reference, validate template access
51+
if (agent.getContextManagementName() != null) {
52+
validateContextManagementTemplateAccess(agent.getContextManagementName(), ActionListener.wrap(templateAccessValid -> {
53+
log.debug("Agent registration validation completed successfully for agent: {}", agent.getName());
54+
listener.onResponse(true);
55+
}, templateAccessError -> {
56+
log.error("Agent registration validation failed - template access error: {}", templateAccessError.getMessage());
57+
listener.onFailure(templateAccessError);
58+
}));
59+
} else {
60+
// No template reference, validation is complete
61+
log.debug("Agent registration validation completed successfully for agent: {}", agent.getName());
62+
listener.onResponse(true);
63+
}
64+
} catch (Exception e) {
65+
log.error("Unexpected error during agent registration validation", e);
66+
listener.onFailure(new IllegalArgumentException("Agent validation failed: " + e.getMessage()));
67+
}
68+
}
69+
3170
/**
3271
* Validates context management template access (following connector access validation pattern).
3372
* This method checks if the template exists and if the user has access to it.

plugin/src/main/java/org/opensearch/ml/action/agents/TransportRegisterAgentAction.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,13 @@ private void registerAgent(MLAgent agent, ActionListener<MLRegisterAgentResponse
108108
log.error("You don't have permission to use the context management template provided, template name: {}", templateName, e);
109109
listener.onFailure(e);
110110
}));
111-
} else {
112-
// Validate inline context management configuration (similar to inline connector validation)
111+
} else if (agent.getInlineContextManagement() != null) {
112+
// Validate inline context management configuration only if it exists (similar to inline connector validation)
113113
validateInlineContextManagement(agent);
114114
continueAgentRegistration(agent, listener);
115+
} else {
116+
// No context management configuration - that's fine, continue with registration
117+
continueAgentRegistration(agent, listener);
115118
}
116119
}
117120

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.action.contextmanagement;
7+
8+
import static org.junit.Assert.assertEquals;
9+
import static org.junit.Assert.assertFalse;
10+
import static org.junit.Assert.assertTrue;
11+
import static org.mockito.Mockito.mock;
12+
import static org.mockito.Mockito.when;
13+
14+
import org.junit.Before;
15+
import org.junit.Test;
16+
import org.opensearch.cluster.ClusterState;
17+
import org.opensearch.cluster.metadata.Metadata;
18+
import org.opensearch.cluster.service.ClusterService;
19+
import org.opensearch.transport.client.Client;
20+
21+
public class ContextManagementIndexUtilsTests {
22+
23+
private ContextManagementIndexUtils contextManagementIndexUtils;
24+
private Client client;
25+
private ClusterService clusterService;
26+
27+
@Before
28+
public void setUp() {
29+
client = mock(Client.class);
30+
clusterService = mock(ClusterService.class);
31+
contextManagementIndexUtils = new ContextManagementIndexUtils(client, clusterService);
32+
}
33+
34+
@Test
35+
public void testGetIndexName() {
36+
// Act
37+
String indexName = ContextManagementIndexUtils.getIndexName();
38+
39+
// Assert
40+
assertEquals("ml_context_management_templates", indexName);
41+
}
42+
43+
@Test
44+
public void testDoesIndexExist_True() {
45+
// Arrange
46+
ClusterState clusterState = mock(ClusterState.class);
47+
Metadata metadata = mock(Metadata.class);
48+
49+
when(clusterService.state()).thenReturn(clusterState);
50+
when(clusterState.metadata()).thenReturn(metadata);
51+
when(metadata.hasIndex("ml_context_management_templates")).thenReturn(true);
52+
53+
// Act
54+
boolean exists = contextManagementIndexUtils.doesIndexExist();
55+
56+
// Assert
57+
assertTrue(exists);
58+
}
59+
60+
@Test
61+
public void testDoesIndexExist_False() {
62+
// Arrange
63+
ClusterState clusterState = mock(ClusterState.class);
64+
Metadata metadata = mock(Metadata.class);
65+
66+
when(clusterService.state()).thenReturn(clusterState);
67+
when(clusterState.metadata()).thenReturn(metadata);
68+
when(metadata.hasIndex("ml_context_management_templates")).thenReturn(false);
69+
70+
// Act
71+
boolean exists = contextManagementIndexUtils.doesIndexExist();
72+
73+
// Assert
74+
assertFalse(exists);
75+
}
76+
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.action.contextmanagement;
7+
8+
import static org.junit.Assert.assertNotNull;
9+
import static org.mockito.Mockito.mock;
10+
11+
import org.junit.Before;
12+
import org.junit.Test;
13+
import org.opensearch.cluster.service.ClusterService;
14+
import org.opensearch.ml.engine.indices.MLIndicesHandler;
15+
import org.opensearch.transport.client.Client;
16+
17+
public class ContextManagementTemplateServiceTests {
18+
19+
private ContextManagementTemplateService contextManagementTemplateService;
20+
private MLIndicesHandler mlIndicesHandler;
21+
private Client client;
22+
private ClusterService clusterService;
23+
24+
@Before
25+
public void setUp() {
26+
mlIndicesHandler = mock(MLIndicesHandler.class);
27+
client = mock(Client.class);
28+
clusterService = mock(ClusterService.class);
29+
contextManagementTemplateService = new ContextManagementTemplateService(mlIndicesHandler, client, clusterService);
30+
}
31+
32+
@Test
33+
public void testConstructor() {
34+
// Assert
35+
assertNotNull(contextManagementTemplateService);
36+
}
37+
}

0 commit comments

Comments
 (0)