From 4533d65bbf3cdd89e382185e73447576519dd524 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Tue, 8 Oct 2024 12:52:35 +0200 Subject: [PATCH] Add builder pattern and order parameter to advisors - Introduce builder pattern for MessageChatMemoryAdvisor, PromptChatMemoryAdvisor, QuestionAnswerAdvisor, SafeGuardAdvisor, and VectorStoreChatMemoryAdvisor. - Add 'order' parameter to control advisor execution priority - Modify constructors to include the new 'order' parameter - Update AbstractChatMemoryAdvisor to support the new 'order' parameter --- .../advisor/AbstractChatMemoryAdvisor.java | 56 +++++++++++++++++-- .../advisor/MessageChatMemoryAdvisor.java | 25 ++++++++- .../advisor/PromptChatMemoryAdvisor.java | 33 ++++++++++- .../client/advisor/QuestionAnswerAdvisor.java | 37 +++++++++++- .../advisor/VectorStoreChatMemoryAdvisor.java | 34 ++++++++++- .../modules/ROOT/pages/api/chatclient.adoc | 8 +-- 6 files changed, 179 insertions(+), 14 deletions(-) diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java index 1ff37a137b9..45f4bf8a705 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/AbstractChatMemoryAdvisor.java @@ -26,6 +26,7 @@ import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; import org.springframework.util.Assert; +import org.stringtemplate.v4.compiler.CodeGenerator.includeExpr_return; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -56,12 +57,20 @@ public abstract class AbstractChatMemoryAdvisor implements CallAroundAdvisor, private final boolean protectFromBlocking; + private final int order; + protected AbstractChatMemoryAdvisor(T chatMemory) { this(chatMemory, DEFAULT_CHAT_MEMORY_CONVERSATION_ID, DEFAULT_CHAT_MEMORY_RESPONSE_SIZE, true); } protected AbstractChatMemoryAdvisor(T chatMemory, String defaultConversationId, int defaultChatMemoryRetrieveSize, boolean protectFromBlocking) { + this(chatMemory, defaultConversationId, defaultChatMemoryRetrieveSize, protectFromBlocking, + Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER); + } + + protected AbstractChatMemoryAdvisor(T chatMemory, String defaultConversationId, int defaultChatMemoryRetrieveSize, + boolean protectFromBlocking, int order) { Assert.notNull(chatMemory, "The chatMemory must not be null!"); Assert.hasText(defaultConversationId, "The conversationId must not be empty!"); @@ -71,6 +80,7 @@ protected AbstractChatMemoryAdvisor(T chatMemory, String defaultConversationId, this.defaultConversationId = defaultConversationId; this.defaultChatMemoryRetrieveSize = defaultChatMemoryRetrieveSize; this.protectFromBlocking = protectFromBlocking; + this.order = order; } @Override @@ -80,11 +90,11 @@ public String getName() { @Override public int getOrder() { - // The (Ordered.HIGHEST_PRECEDENCE + 1000) value ensures this order has lower - // priority (e.g. precedences) than the internal Spring AI advisors. It leaves - // room (1000 slots) for the user to plug in their own advisors with higher + // by default the (Ordered.HIGHEST_PRECEDENCE + 1000) value ensures this order has + // lower priority (e.g. precedences) than the internal Spring AI advisors. It + // leaves room (1000 slots) for the user to plug in their own advisors with higher // priority. - return Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER; + return this.order; } protected T getChatMemoryStore() { @@ -118,5 +128,43 @@ protected Flux doNextWithProtectFromBlockingBefore(AdvisedReque : chain.nextAroundStream(beforeAdvise.apply(advisedRequest)); } + public static abstract class AbstractBuilder { + + protected String conversationId = DEFAULT_CHAT_MEMORY_CONVERSATION_ID; + + protected int chatMemoryRetrieveSize = DEFAULT_CHAT_MEMORY_RESPONSE_SIZE; + + protected boolean protectFromBlocking = true; + + protected int order = Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER; + + protected T chatMemory; + + protected AbstractBuilder(T chatMemory) { + this.chatMemory = chatMemory; + } + + public AbstractBuilder withConversationId(String conversationId) { + this.conversationId = conversationId; + return this; + } + + public AbstractBuilder withChatMemoryRetrieveSize(int chatMemoryRetrieveSize) { + this.chatMemoryRetrieveSize = chatMemoryRetrieveSize; + return this; + } + + public AbstractBuilder withProtectFromBlocking(boolean protectFromBlocking) { + this.protectFromBlocking = protectFromBlocking; + return this; + } + + public AbstractBuilder withOrder(int order) { + this.order = order; + return this; + } + + abstract public AbstractChatMemoryAdvisor build(); + } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java index 66d4ee34755..0677e28ad19 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java @@ -21,6 +21,7 @@ import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; +import org.springframework.ai.chat.client.advisor.api.Advisor; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; import org.springframework.ai.chat.memory.ChatMemory; @@ -43,7 +44,12 @@ public MessageChatMemoryAdvisor(ChatMemory chatMemory) { } public MessageChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversationId, int chatHistoryWindowSize) { - super(chatMemory, defaultConversationId, chatHistoryWindowSize, true); + this(chatMemory, defaultConversationId, chatHistoryWindowSize, Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER); + } + + public MessageChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversationId, int chatHistoryWindowSize, + int order) { + super(chatMemory, defaultConversationId, chatHistoryWindowSize, true, order); } @Override @@ -101,4 +107,21 @@ private void observeAfter(AdvisedResponse advisedResponse) { this.getChatMemoryStore().add(this.doGetConversationId(advisedResponse.adviseContext()), assistantMessages); } + public static Builder builder(ChatMemory chatMemory) { + return new Builder(chatMemory); + } + + public static class Builder extends AbstractChatMemoryAdvisor.AbstractBuilder { + + protected Builder(ChatMemory chatMemory) { + super(chatMemory); + } + + public MessageChatMemoryAdvisor build() { + return new MessageChatMemoryAdvisor(this.chatMemory, this.conversationId, this.chatMemoryRetrieveSize, + this.order); + } + + } + } \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java index 961bca0f32b..d183ab31697 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java @@ -23,6 +23,7 @@ import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; +import org.springframework.ai.chat.client.advisor.api.Advisor; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; import org.springframework.ai.chat.memory.ChatMemory; @@ -66,7 +67,13 @@ public PromptChatMemoryAdvisor(ChatMemory chatMemory, String systemTextAdvise) { public PromptChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversationId, int chatHistoryWindowSize, String systemTextAdvise) { - super(chatMemory, defaultConversationId, chatHistoryWindowSize, true); + this(chatMemory, defaultConversationId, chatHistoryWindowSize, systemTextAdvise, + Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER); + } + + public PromptChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversationId, int chatHistoryWindowSize, + String systemTextAdvise, int order) { + super(chatMemory, defaultConversationId, chatHistoryWindowSize, true, order); this.systemTextAdvise = systemTextAdvise; } @@ -133,4 +140,28 @@ private void observeAfter(AdvisedResponse advisedResponse) { this.getChatMemoryStore().add(this.doGetConversationId(advisedResponse.adviseContext()), assistantMessages); } + public static Builder builder(ChatMemory chatMemory) { + return new Builder(chatMemory); + } + + public static class Builder extends AbstractChatMemoryAdvisor.AbstractBuilder { + + private String systemTextAdvise = DEFAULT_SYSTEM_TEXT_ADVISE; + + protected Builder(ChatMemory chatMemory) { + super(chatMemory); + } + + public Builder withSystemTextAdvise(String systemTextAdvise) { + this.systemTextAdvise = systemTextAdvise; + return this; + } + + public PromptChatMemoryAdvisor build() { + return new PromptChatMemoryAdvisor(this.chatMemory, this.conversationId, this.chatMemoryRetrieveSize, + this.systemTextAdvise, this.order); + } + + } + } \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java index c00f8e0bfb9..788c47a14b9 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/QuestionAnswerAdvisor.java @@ -24,8 +24,8 @@ import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; -import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor; +import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisor; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; import org.springframework.ai.chat.model.ChatResponse; @@ -61,6 +61,8 @@ public class QuestionAnswerAdvisor implements CallAroundAdvisor, StreamAroundAdv the user that you can't answer the question. """; + private static final int DEFAULT_ORDER = 0; + private final VectorStore vectorStore; private final String userTextAdvise; @@ -73,6 +75,8 @@ public class QuestionAnswerAdvisor implements CallAroundAdvisor, StreamAroundAdv private final boolean protectFromBlocking; + private final int order; + /** * The QuestionAnswerAdvisor retrieves context information from a Vector Store and * combines it with the user's text. @@ -121,6 +125,25 @@ public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchReques */ public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String userTextAdvise, boolean protectFromBlocking) { + this(vectorStore, searchRequest, userTextAdvise, protectFromBlocking, DEFAULT_ORDER); + } + + /** + * The QuestionAnswerAdvisor retrieves context information from a Vector Store and + * combines it with the user's text. + * @param vectorStore The vector store to use + * @param searchRequest The search request defined using the portable filter + * expression syntax + * @param userTextAdvise the user text to append to the existing user prompt. The text + * should contain a placeholder named "question_answer_context". + * @param protectFromBlocking if true the advisor will protect the execution from + * blocking threads. If false the advisor will not protect the execution from blocking + * threads. This is useful when the advisor is used in a non-blocking environment. It + * is true by default. + * @param order the order of the advisor. + */ + public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchRequest, String userTextAdvise, + boolean protectFromBlocking, int order) { Assert.notNull(vectorStore, "The vectorStore must not be null!"); Assert.notNull(searchRequest, "The searchRequest must not be null!"); @@ -130,6 +153,7 @@ public QuestionAnswerAdvisor(VectorStore vectorStore, SearchRequest searchReques this.searchRequest = searchRequest; this.userTextAdvise = userTextAdvise; this.protectFromBlocking = protectFromBlocking; + this.order = order; } @Override @@ -139,7 +163,7 @@ public String getName() { @Override public int getOrder() { - return 0; + return this.order; } @Override @@ -249,6 +273,8 @@ public static class Builder { private boolean protectFromBlocking = true; + private int order = DEFAULT_ORDER; + private Builder(VectorStore vectorStore) { Assert.notNull(vectorStore, "The vectorStore must not be null!"); this.vectorStore = vectorStore; @@ -271,9 +297,14 @@ public Builder withProtectFromBlocking(boolean protectFromBlocking) { return this; } + public Builder withOrder(int order) { + this.order = order; + return this; + } + public QuestionAnswerAdvisor build() { return new QuestionAnswerAdvisor(this.vectorStore, this.searchRequest, this.userTextAdvise, - this.protectFromBlocking); + this.protectFromBlocking, this.order); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java index b6ab9439877..8e7da1ad02f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/VectorStoreChatMemoryAdvisor.java @@ -23,6 +23,7 @@ import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; +import org.springframework.ai.chat.client.advisor.api.Advisor; import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.StreamAroundAdvisorChain; import org.springframework.ai.chat.messages.AssistantMessage; @@ -78,7 +79,13 @@ public VectorStoreChatMemoryAdvisor(VectorStore vectorStore, String defaultConve public VectorStoreChatMemoryAdvisor(VectorStore vectorStore, String defaultConversationId, int chatHistoryWindowSize, String systemTextAdvise) { - super(vectorStore, defaultConversationId, chatHistoryWindowSize, true); + this(vectorStore, defaultConversationId, chatHistoryWindowSize, systemTextAdvise, + Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER); + } + + public VectorStoreChatMemoryAdvisor(VectorStore vectorStore, String defaultConversationId, + int chatHistoryWindowSize, String systemTextAdvise, int order) { + super(vectorStore, defaultConversationId, chatHistoryWindowSize, true, order); this.systemTextAdvise = systemTextAdvise; } @@ -168,4 +175,29 @@ else if (message instanceof AssistantMessage assistantMessage) { return docs; } + public static Builder builder(VectorStore chatMemory) { + return new Builder(chatMemory); + } + + public static class Builder extends AbstractChatMemoryAdvisor.AbstractBuilder { + + private String systemTextAdvise = DEFAULT_SYSTEM_TEXT_ADVISE; + + protected Builder(VectorStore chatMemory) { + super(chatMemory); + } + + public Builder withSystemTextAdvise(String systemTextAdvise) { + this.systemTextAdvise = systemTextAdvise; + return this; + } + + @Override + public VectorStoreChatMemoryAdvisor build() { + return new VectorStoreChatMemoryAdvisor(this.chatMemory, this.conversationId, this.chatMemoryRetrieveSize, + this.systemTextAdvise); + } + + } + } diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc index 89025131c8e..18efb365583 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc @@ -423,7 +423,8 @@ The following advisor implementations use the `ChatMemory` interface to advice t * `MessageChatMemoryAdvisor` : Memory is retrieved and added as a collection of messages to the prompt * `PromptChatMemoryAdvisor` : Memory is retrieved and added into the prompt's system text. -* `VectorStoreChatMemoryAdvisor` : The constructor `VectorStoreChatMemoryAdvisor(VectorStore vectorStore, String defaultConversationId, int chatHistoryWindowSize)` lets you specify the VectorStore to retrieve the chat history from, the unique conversation ID, the size of the chat history to be retrieved in token size. +* `VectorStoreChatMemoryAdvisor` : The constructor `VectorStoreChatMemoryAdvisor(VectorStore vectorStore, String defaultConversationId, int chatHistoryWindowSize, int order)` lets you specify the VectorStore to retrieve the chat history from, the unique conversation ID, the size of the chat history to be retrieved in token size. +The VectorStoreChatMemoryAdvisor.builder() method lets you specify the default conversation ID, the chat history window size, and the order of the chat history to be retrieved. A sample `@Service` implementation that uses several advisors is shown below. @@ -452,10 +453,9 @@ public class CustomerSupportAssistant { If there is a charge for the change, you MUST ask the user to consent before proceeding. """) .defaultAdvisors( - new PromptChatMemoryAdvisor(chatMemory), - // new MessageChatMemoryAdvisor(chatMemory), // CHAT MEMORY + new MessageChatMemoryAdvisor(chatMemory), // CHAT MEMORY new QuestionAnswerAdvisor(vectorStore, SearchRequest.defaults()), // RAG - new LoggingAdvisor()) + new SimpleLoggerAdvisor()) .defaultFunctions("getBookingDetails", "changeBooking", "cancelBooking") // FUNCTION CALLING .build(); }