Skip to content

Commit fed7e52

Browse files
authored
refactor memory interface; add agentic conversation memory (#4434)
Signed-off-by: Yaliang Wu <ylwu@amazon.com>
1 parent f4ac35c commit fed7e52

File tree

25 files changed

+1369
-459
lines changed

25 files changed

+1369
-459
lines changed
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common;
7+
8+
import java.util.Locale;
9+
10+
public enum MLMemoryType {
11+
CONVERSATION_INDEX,
12+
AGENTIC_MEMORY,
13+
REMOTE_AGENTIC_MEMORY;
14+
15+
public static MLMemoryType from(String value) {
16+
if (value != null) {
17+
try {
18+
return MLMemoryType.valueOf(value.toUpperCase(Locale.ROOT));
19+
} catch (Exception e) {
20+
throw new IllegalArgumentException("Wrong Memory type");
21+
}
22+
}
23+
return null;
24+
}
25+
}

common/src/main/java/org/opensearch/ml/common/agent/MLMemorySpec.java

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,32 +26,37 @@ public class MLMemorySpec implements ToXContentObject {
2626
public static final String MEMORY_TYPE_FIELD = "type";
2727
public static final String WINDOW_SIZE_FIELD = "window_size";
2828
public static final String SESSION_ID_FIELD = "session_id";
29+
public static final String MEMORY_CONTAINER_ID_FIELD = "memory_container_id";
2930

3031
private String type;
3132
@Setter
3233
private String sessionId;
3334
private Integer windowSize;
35+
private String memoryContainerId;
3436

3537
@Builder(toBuilder = true)
36-
public MLMemorySpec(String type, String sessionId, Integer windowSize) {
38+
public MLMemorySpec(String type, String sessionId, Integer windowSize, String memoryContainerId) {
3739
if (type == null) {
3840
throw new IllegalArgumentException("agent name is null");
3941
}
4042
this.type = type;
4143
this.sessionId = sessionId;
4244
this.windowSize = windowSize;
45+
this.memoryContainerId = memoryContainerId;
4346
}
4447

4548
public MLMemorySpec(StreamInput input) throws IOException {
4649
type = input.readString();
4750
sessionId = input.readOptionalString();
4851
windowSize = input.readOptionalInt();
52+
memoryContainerId = input.readOptionalString();
4953
}
5054

5155
public void writeTo(StreamOutput out) throws IOException {
5256
out.writeString(type);
5357
out.writeOptionalString(sessionId);
5458
out.writeOptionalInt(windowSize);
59+
out.writeOptionalString(memoryContainerId);
5560
}
5661

5762
@Override
@@ -64,6 +69,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
6469
if (sessionId != null) {
6570
builder.field(SESSION_ID_FIELD, sessionId);
6671
}
72+
if (memoryContainerId != null) {
73+
builder.field(MEMORY_CONTAINER_ID_FIELD, memoryContainerId);
74+
}
6775
builder.endObject();
6876
return builder;
6977
}
@@ -72,6 +80,7 @@ public static MLMemorySpec parse(XContentParser parser) throws IOException {
7280
String type = null;
7381
String sessionId = null;
7482
Integer windowSize = null;
83+
String memoryContainerId = null;
7584

7685
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
7786
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
@@ -88,12 +97,15 @@ public static MLMemorySpec parse(XContentParser parser) throws IOException {
8897
case WINDOW_SIZE_FIELD:
8998
windowSize = parser.intValue();
9099
break;
100+
case MEMORY_CONTAINER_ID_FIELD:
101+
memoryContainerId = parser.text();
102+
break;
91103
default:
92104
parser.skipChildren();
93105
break;
94106
}
95107
}
96-
return MLMemorySpec.builder().type(type).sessionId(sessionId).windowSize(windowSize).build();
108+
return MLMemorySpec.builder().type(type).sessionId(sessionId).windowSize(windowSize).memoryContainerId(memoryContainerId).build();
97109
}
98110

99111
public static MLMemorySpec fromStream(StreamInput in) throws IOException {

common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.opensearch.core.xcontent.ToXContentObject;
2929
import org.opensearch.core.xcontent.XContentBuilder;
3030
import org.opensearch.ml.common.CommonValue;
31+
import org.opensearch.ml.common.memory.Message;
3132
import org.opensearch.search.SearchHit;
3233

3334
import lombok.AllArgsConstructor;
@@ -39,7 +40,7 @@
3940
*/
4041
@Builder
4142
@AllArgsConstructor
42-
public class Interaction implements Writeable, ToXContentObject {
43+
public class Interaction implements Writeable, ToXContentObject, Message {
4344

4445
@Getter
4546
private String id;
@@ -275,4 +276,13 @@ public String toString() {
275276
+ "}";
276277
}
277278

279+
@Override
280+
public String getType() {
281+
return "";
282+
}
283+
284+
@Override
285+
public String getContent() {
286+
return "";
287+
}
278288
}
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.memory;
7+
8+
import java.util.List;
9+
import java.util.Map;
10+
11+
import org.opensearch.core.action.ActionListener;
12+
13+
/**
14+
* A general memory interface.
15+
* @param <T> Message type
16+
* @param <R> Save response type
17+
* @param <S> Update response type
18+
*/
19+
public interface Memory<T extends Message, R, S> {
20+
21+
/**
22+
* Get memory type.
23+
* @return memory type
24+
*/
25+
String getType();
26+
27+
/**
28+
* Get memory ID.
29+
* @return memory ID
30+
*/
31+
String getId();
32+
33+
default void save(Message message, String parentId, Integer traceNum, String action) {}
34+
35+
default void save(Message message, String parentId, Integer traceNum, String action, ActionListener<R> listener) {}
36+
37+
default void update(String messageId, Map<String, Object> updateContent, ActionListener<S> updateListener) {}
38+
39+
default void getMessages(int size, ActionListener<List<T>> listener) {}
40+
41+
/**
42+
* Clear all memory.
43+
*/
44+
void clear();
45+
46+
void deleteInteractionAndTrace(String regenerateInteractionId, ActionListener<Boolean> wrap);
47+
48+
interface Factory<M extends Memory> {
49+
/**
50+
* Create an instance of this Memory.
51+
*
52+
* @param params Parameters for the memory
53+
* @param listener Action listener for the memory creation action
54+
*/
55+
void create(Map<String, Object> params, ActionListener<M> listener);
56+
}
57+
}

spi/src/main/java/org/opensearch/ml/common/spi/memory/Message.java renamed to common/src/main/java/org/opensearch/ml/common/memory/Message.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6-
package org.opensearch.ml.common.spi.memory;
6+
package org.opensearch.ml.common.memory;
77

88
/**
99
* General message interface.
@@ -12,13 +12,13 @@ public interface Message {
1212

1313
/**
1414
* Get message type.
15-
* @return
15+
* @return message type
1616
*/
1717
String getType();
1818

1919
/**
2020
* Get message content.
21-
* @return
21+
* @return message content
2222
*/
2323
String getContent();
2424
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/agents/AgentContextUtil.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import org.opensearch.ml.common.hooks.EnhancedPostToolEvent;
1818
import org.opensearch.ml.common.hooks.HookRegistry;
1919
import org.opensearch.ml.common.hooks.PreLLMEvent;
20-
import org.opensearch.ml.common.spi.memory.Memory;
20+
import org.opensearch.ml.common.memory.Memory;
2121
import org.opensearch.ml.common.utils.StringUtils;
2222
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
2323

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import static org.opensearch.ml.common.CommonValue.MCP_CONNECTORS_FIELD;
1111
import static org.opensearch.ml.common.CommonValue.MCP_CONNECTOR_ID_FIELD;
1212
import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;
13+
import static org.opensearch.ml.common.agent.MLMemorySpec.MEMORY_CONTAINER_ID_FIELD;
1314
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;
1415
import static org.opensearch.ml.common.utils.StringUtils.gson;
1516
import static org.opensearch.ml.common.utils.StringUtils.isJson;
@@ -29,6 +30,7 @@
2930
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_DESCRIPTIONS;
3031
import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.TOOL_NAMES;
3132
import static org.opensearch.ml.engine.algorithms.agent.MLPlanExecuteAndReflectAgentRunner.RESPONSE_FIELD;
33+
import static org.opensearch.ml.engine.memory.ConversationIndexMemory.APP_TYPE;
3234
import static org.opensearch.ml.engine.memory.ConversationIndexMemory.LAST_N_INTERACTIONS;
3335

3436
import java.io.IOException;
@@ -83,6 +85,7 @@
8385
import org.opensearch.ml.engine.algorithms.remote.McpStreamableHttpConnectorExecutor;
8486
import org.opensearch.ml.engine.encryptor.Encryptor;
8587
import org.opensearch.ml.engine.function_calling.FunctionCalling;
88+
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
8689
import org.opensearch.ml.engine.tools.McpSseTool;
8790
import org.opensearch.ml.engine.tools.McpStreamableHttpTool;
8891
import org.opensearch.remote.metadata.client.GetDataObjectRequest;
@@ -1014,4 +1017,22 @@ public static Tool createTool(Map<String, Tool.Factory> toolFactories, Map<Strin
10141017

10151018
return tool;
10161019
}
1020+
1021+
public static Map<String, Object> createMemoryParams(
1022+
String question,
1023+
String memoryId,
1024+
String appType,
1025+
MLAgent mlAgent,
1026+
String memoryContainerId
1027+
) {
1028+
Map<String, Object> memoryParams = new HashMap<>();
1029+
memoryParams.put(ConversationIndexMemory.MEMORY_NAME, question);
1030+
memoryParams.put(ConversationIndexMemory.MEMORY_ID, memoryId);
1031+
memoryParams.put(APP_TYPE, appType);
1032+
if (mlAgent.getMemory().getMemoryContainerId() != null) {
1033+
memoryParams.put(MEMORY_CONTAINER_ID_FIELD, mlAgent.getMemory().getMemoryContainerId());
1034+
}
1035+
memoryParams.putIfAbsent(MEMORY_CONTAINER_ID_FIELD, memoryContainerId);
1036+
return memoryParams;
1037+
}
10171038
}

0 commit comments

Comments
 (0)