Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;
Expand Down Expand Up @@ -413,8 +414,15 @@ else if (message.getMessageType() == MessageType.TOOL) {
systemPrompt, this.defaultOptions.getMaxTokens(), this.defaultOptions.getTemperature(), stream);

if (prompt.getOptions() != null) {
AnthropicChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(),
ChatOptions.class, AnthropicChatOptions.class);
AnthropicChatOptions updatedRuntimeOptions;
if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions,
FunctionCallingOptions.class, AnthropicChatOptions.class);
}
else {
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
AnthropicChatOptions.class);
}

functionsForThisRequest.addAll(this.runtimeFunctionCallbackConfigurations(updatedRuntimeOptions));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;

import org.springframework.ai.azure.openai.metadata.AzureOpenAiUsage;
import org.springframework.ai.chat.messages.AssistantMessage;
Expand All @@ -47,6 +48,7 @@
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

Expand Down Expand Up @@ -286,8 +288,15 @@ ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) {
functionsForThisRequest.addAll(this.defaultOptions.getFunctions());

if (prompt.getOptions() != null) {
AzureOpenAiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(),
ChatOptions.class, AzureOpenAiChatOptions.class);
AzureOpenAiChatOptions updatedRuntimeOptions;
if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions,
FunctionCallingOptions.class, AzureOpenAiChatOptions.class);
}
else {
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
AzureOpenAiChatOptions.class);
}
options = this.merge(updatedRuntimeOptions, options);

functionsForThisRequest.addAll(this.runtimeFunctionCallbackConfigurations(updatedRuntimeOptions));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;
Expand Down Expand Up @@ -391,8 +392,16 @@ else if (message.getMessageType() == MessageType.TOOL) {
Set<String> enabledToolsToUse = new HashSet<>();

if (prompt.getOptions() != null) {
MiniMaxChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(),
ChatOptions.class, MiniMaxChatOptions.class);
MiniMaxChatOptions updatedRuntimeOptions;

if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions,
FunctionCallingOptions.class, MiniMaxChatOptions.class);
}
else {
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
MiniMaxChatOptions.class);
}

enabledToolsToUse.addAll(this.runtimeFunctionCallbackConfigurations(updatedRuntimeOptions));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;
Expand Down Expand Up @@ -367,8 +368,16 @@ else if (message instanceof ToolResponseMessage toolResponseMessage) {
request = ModelOptionsUtils.merge(request, this.defaultOptions, MistralAiApi.ChatCompletionRequest.class);

if (prompt.getOptions() != null) {
var updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
MistralAiChatOptions.class);
MistralAiChatOptions updatedRuntimeOptions;

if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions,
FunctionCallingOptions.class, MistralAiChatOptions.class);
}
else {
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
MistralAiChatOptions.class);
}

functionsForThisRequest.addAll(this.runtimeFunctionCallbackConfigurations(updatedRuntimeOptions));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.moonshot.api.MoonshotApi;
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletion;
import org.springframework.ai.moonshot.api.MoonshotApi.ChatCompletion.Choice;
Expand Down Expand Up @@ -341,9 +342,16 @@ else if (message.getMessageType() == MessageType.TOOL) {
Set<String> enabledToolsToUse = new HashSet<>();

if (prompt.getOptions() != null) {
MoonshotChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(),
ChatOptions.class, MoonshotChatOptions.class);
MoonshotChatOptions updatedRuntimeOptions;

if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions,
FunctionCallingOptions.class, MoonshotChatOptions.class);
}
else {
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
MoonshotChatOptions.class);
}
enabledToolsToUse.addAll(this.runtimeFunctionCallbackConfigurations(updatedRuntimeOptions));

request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, ChatCompletionRequest.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaApi.ChatRequest;
import org.springframework.ai.ollama.api.OllamaApi.Message.Role;
Expand Down Expand Up @@ -297,8 +298,14 @@ else if (message instanceof ToolResponseMessage toolMessage) {
// runtime options
OllamaOptions runtimeOptions = null;
if (prompt.getOptions() != null) {
runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
OllamaOptions.class);
if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
runtimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions, FunctionCallingOptions.class,
OllamaOptions.class);
}
else {
runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
OllamaOptions.class);
}
functionsForThisRequest.addAll(this.runtimeFunctionCallbackConfigurations(runtimeOptions));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion.Choice;
Expand Down Expand Up @@ -477,8 +478,16 @@ else if (message.getMessageType() == MessageType.TOOL) {
Set<String> enabledToolsToUse = new HashSet<>();

if (prompt.getOptions() != null) {
OpenAiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(),
ChatOptions.class, OpenAiChatOptions.class);
OpenAiChatOptions updatedRuntimeOptions = null;

if (prompt.getOptions() instanceof FunctionCallingOptions) {
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(((FunctionCallingOptions) prompt.getOptions()),
FunctionCallingOptions.class, OpenAiChatOptions.class);
}
else if (prompt.getOptions() instanceof OpenAiChatOptions) {
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
OpenAiChatOptions.class);
}

enabledToolsToUse.addAll(this.runtimeFunctionCallbackConfigurations(updatedRuntimeOptions));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,15 +181,9 @@ public ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
}

if (prompt.getOptions() != null) {
if (prompt.getOptions() != null) {
var updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
QianFanChatOptions.class);
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, ChatCompletionRequest.class);
}
else {
throw new IllegalArgumentException("Prompt options are not of type ChatOptions: "
+ prompt.getOptions().getClass().getSimpleName());
}
var updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
QianFanChatOptions.class);
request = ModelOptionsUtils.merge(updatedRuntimeOptions, request, ChatCompletionRequest.class);
}
return request;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.vertexai.gemini.metadata.VertexAiUsage;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.lang.NonNull;
Expand Down Expand Up @@ -80,8 +81,6 @@
*/
public class VertexAiGeminiChatModel extends AbstractToolCallSupport implements ChatModel, DisposableBean {

private final static boolean IS_RUNTIME_CALL = true;

private final VertexAI vertexAI;

private final VertexAiGeminiChatOptions defaultOptions;
Expand Down Expand Up @@ -292,9 +291,15 @@ GeminiRequest createGeminiRequest(Prompt prompt) {
VertexAiGeminiChatOptions updatedRuntimeOptions = VertexAiGeminiChatOptions.builder().build();

if (prompt.getOptions() != null) {
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
VertexAiGeminiChatOptions.class);
if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions,
FunctionCallingOptions.class, VertexAiGeminiChatOptions.class);

}
else {
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
VertexAiGeminiChatOptions.class);
}
functionsForThisRequest.addAll(runtimeFunctionCallbackConfigurations(updatedRuntimeOptions));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.zhipuai.api.ZhiPuAiApi;
import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletion;
Expand Down Expand Up @@ -358,8 +359,15 @@ else if (message.getMessageType() == MessageType.TOOL) {
Set<String> enabledToolsToUse = new HashSet<>();

if (prompt.getOptions() != null) {
ZhiPuAiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(),
ChatOptions.class, ZhiPuAiChatOptions.class);
ZhiPuAiChatOptions updatedRuntimeOptions;
if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions,
FunctionCallingOptions.class, ZhiPuAiChatOptions.class);
}
else {
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
ZhiPuAiChatOptions.class);
}

enabledToolsToUse.addAll(this.runtimeFunctionCallbackConfigurations(updatedRuntimeOptions));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
import java.util.List;
import java.util.Set;

import org.springframework.ai.chat.prompt.ChatOptions;

/**
* @author Christian Tzolov
*/
public interface FunctionCallingOptions {
public interface FunctionCallingOptions extends ChatOptions {

/**
* Function Callbacks to be registered with the ChatModel. For Prompt Options the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.context.annotation.Bean;
Expand Down Expand Up @@ -79,6 +80,28 @@ void functionCallTest() {
});
}

@Test
void functionCallWithPortableFunctionCallingOptions() {

contextRunner
.withPropertyValues(
"spring.ai.anthropic.chat.options.model=" + AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue())
.run(context -> {

AnthropicChatModel chatModel = context.getBean(AnthropicChatModel.class);

var userMessage = new UserMessage(
"What's the weather like in San Francisco, in Paris, France and in Tokyo, Japan? Return the temperature in Celsius.");

ChatResponse response = chatModel.call(new Prompt(List.of(userMessage),
PortableFunctionCallingOptions.builder().withFunction("weatherFunction").build()));

logger.info("Response: {}", response);

assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15");
});
}

@Configuration
static class Config {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.context.annotation.Bean;
Expand Down Expand Up @@ -80,6 +81,26 @@ void functionCallTest() {
});
}

@Test
void functionCallWithPortableFunctionCallingOptions() {
contextRunner.withPropertyValues("spring.ai.azure.openai.chat.options..deployment-name=" + getDeploymentName())
.run(context -> {

ChatModel chatModel = context.getBean(AzureOpenAiChatModel.class);

UserMessage userMessage = new UserMessage(
"What's the weather like in San Francisco, Paris and in Tokyo? Use Multi-turn function calling.");

ChatResponse response = chatModel.call(new Prompt(List.of(userMessage),
PortableFunctionCallingOptions.builder().withFunction("weatherFunction").build()));

logger.info("Response: {}", response);

assertThat(response.getResult().getOutput().getContent()).contains("30", "10", "15");

});
}

@Configuration
static class Config {

Expand Down
Loading