Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

import org.springframework.ai.anthropic.api.AnthropicApi;
import org.springframework.ai.anthropic.api.AnthropicApi.AnthropicMessage;
Expand Down Expand Up @@ -268,42 +267,37 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
Usage accumulatedUsage = UsageCalculator.getCumulativeUsage(currentChatResponseUsage, previousChatResponse);
ChatResponse chatResponse = toChatResponse(chatCompletionResponse, accumulatedUsage);

if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse)) {

if (chatResponse.hasFinishReasons(Set.of("tool_use"))) {
// FIXME: bounded elastic needs to be used since tool calling
// is currently only synchronous
return Flux.deferContextual(ctx -> {
// TODO: factor out the tool execution logic with setting context into a utility.
ToolExecutionResult toolExecutionResult;
try {
ToolCallReactiveContextHolder.setContext(ctx);
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
}
finally {
ToolCallReactiveContextHolder.clearContext();
}
if (toolExecutionResult.returnDirect()) {
// Return tool execution result directly to the client.
return Flux.just(ChatResponse.builder().from(chatResponse)
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
.build());
}
else {
// Send the tool execution result back to the model.
return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
chatResponse);
}
}).subscribeOn(Schedulers.boundedElastic());
}
else {
return Mono.empty();
}
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse)) {

if (chatResponse.hasFinishReasons(Set.of("tool_use"))) {
return Flux.deferContextual(ctx -> {
// TODO: factor out the tool execution logic with setting context into a utility.
ToolCallReactiveContextHolder.setContext(ctx);
return this.toolCallingManager.executeToolCallsAsync(prompt, chatResponse)
.doFinally(s -> ToolCallReactiveContextHolder.clearContext())
.flatMapMany(toolExecutionResult -> {
if (toolExecutionResult.returnDirect()) {
// Return tool execution result directly to the client.
return Flux.just(ChatResponse.builder().from(chatResponse)
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
.build());
}
else {
// Send the tool execution result back to the model.
return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
chatResponse);
}
});
});
}
else {
// If internal tool execution is not required, just return the chat response.
return Mono.just(chatResponse);
return Mono.empty();
}
}
else {
// If internal tool execution is not required, just return the chat response.
return Mono.just(chatResponse);
}
})
.doOnError(observation::error)
.doFinally(s -> observation.stop())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.scheduler.Schedulers;

import org.springframework.ai.azure.openai.AzureOpenAiResponseFormat.JsonSchema;
import org.springframework.ai.azure.openai.AzureOpenAiResponseFormat.Type;
Expand Down Expand Up @@ -379,31 +378,27 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha

return chatResponseFlux.flatMapSequential(chatResponse -> {
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse)) {
// FIXME: bounded elastic needs to be used since tool calling
// is currently only synchronous
return Flux.deferContextual(ctx -> {
ToolExecutionResult toolExecutionResult;
try {
ToolCallReactiveContextHolder.setContext(ctx);
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
}
finally {
ToolCallReactiveContextHolder.clearContext();
}
if (toolExecutionResult.returnDirect()) {
// Return tool execution result directly to the client.
return Flux.just(ChatResponse.builder()
.from(chatResponse)
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
.build());
}
else {
// Send the tool execution result back to the model.
return this.internalStream(
new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
chatResponse);
}
}).subscribeOn(Schedulers.boundedElastic());
ToolCallReactiveContextHolder.setContext(ctx);
return this.toolCallingManager.executeToolCallsAsync(prompt, chatResponse)
.doFinally(s -> ToolCallReactiveContextHolder.clearContext())
.flatMapMany(toolExecutionResult -> {
if (toolExecutionResult.returnDirect()) {
// Return tool execution result directly to the
// client.
return Flux.just(ChatResponse.builder()
.from(chatResponse)
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
.build());
}
else {
// Send the tool execution result back to the model.
return this.internalStream(
new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
chatResponse);
}
});
});
}

Flux<ChatResponse> flux = Flux.just(chatResponse)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.scheduler.Schedulers;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.core.document.Document;
Expand Down Expand Up @@ -805,32 +804,27 @@ private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse perviousCh
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse)
&& chatResponse.hasFinishReasons(Set.of(StopReason.TOOL_USE.toString()))) {

// FIXME: bounded elastic needs to be used since tool calling
// is currently only synchronous
return Flux.deferContextual(ctx -> {
ToolExecutionResult toolExecutionResult;
try {
ToolCallReactiveContextHolder.setContext(ctx);
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
}
finally {
ToolCallReactiveContextHolder.clearContext();
}

if (toolExecutionResult.returnDirect()) {
// Return tool execution result directly to the client.
return Flux.just(ChatResponse.builder()
.from(chatResponse)
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
.build());
}
else {
// Send the tool execution result back to the model.
return this.internalStream(
new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
chatResponse);
}
}).subscribeOn(Schedulers.boundedElastic());
ToolCallReactiveContextHolder.setContext(ctx);
return this.toolCallingManager.executeToolCallsAsync(prompt, chatResponse)
.doFinally(s -> ToolCallReactiveContextHolder.clearContext())
.flatMapMany(toolExecutionResult -> {
if (toolExecutionResult.returnDirect()) {
// Return tool execution result directly to the
// client.
return Flux.just(ChatResponse.builder()
.from(chatResponse)
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
.build());
}
else {
// Send the tool execution result back to the model.
return this.internalStream(
new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
chatResponse);
}
});
});
}
else {
return Flux.just(chatResponse);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.MessageType;
Expand Down Expand Up @@ -285,36 +284,31 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
}));

// @formatter:off
Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
// FIXME: bounded elastic needs to be used since tool calling
// is currently only synchronous
return Flux.deferContextual(ctx -> {
ToolExecutionResult toolExecutionResult;
try {
ToolCallReactiveContextHolder.setContext(ctx);
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
}
finally {
ToolCallReactiveContextHolder.clearContext();
}
if (toolExecutionResult.returnDirect()) {
// Return tool execution result directly to the client.
return Flux.just(ChatResponse.builder().from(response)
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
.build());
}
else {
// Send the tool execution result back to the model.
return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
response);
}
}).subscribeOn(Schedulers.boundedElastic());
}
else {
return Flux.just(response);
}
})
Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
return Flux.deferContextual(ctx -> {
ToolCallReactiveContextHolder.setContext(ctx);
return this.toolCallingManager.executeToolCallsAsync(prompt, response)
.doFinally(s -> ToolCallReactiveContextHolder.clearContext())
.flatMapMany(toolExecutionResult -> {
if (toolExecutionResult.returnDirect()) {
// Return tool execution result directly to the client.
return Flux.just(ChatResponse.builder().from(response)
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
.build());
}
else {
// Send the tool execution result back to the model.
return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
response);
}
});
});
}
else {
return Flux.just(response);
}
})
.doOnError(observation::error)
.doFinally(s -> observation.stop())
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.scheduler.Schedulers;

import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
Expand Down Expand Up @@ -548,35 +547,30 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
});

// @formatter:off
Flux<ChatResponse> flux = chatResponseFlux.flatMap(response -> {
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
// FIXME: bounded elastic needs to be used since tool calling
// is currently only synchronous
return Flux.deferContextual(ctx -> {
ToolExecutionResult toolExecutionResult;
try {
ToolCallReactiveContextHolder.setContext(ctx);
toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
}
finally {
ToolCallReactiveContextHolder.clearContext();
}
if (toolExecutionResult.returnDirect()) {
// Return tool execution result directly to the client.
return Flux.just(ChatResponse.builder().from(response)
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
.build());
}
else {
// Send the tool execution result back to the model.
return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), response);
}
}).subscribeOn(Schedulers.boundedElastic());
}
else {
return Flux.just(response);
}
})
Flux<ChatResponse> flux = chatResponseFlux.flatMap(response -> {
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
return Flux.deferContextual(ctx -> {
ToolCallReactiveContextHolder.setContext(ctx);
return this.toolCallingManager.executeToolCallsAsync(prompt, response)
.doFinally(s -> ToolCallReactiveContextHolder.clearContext())
.flatMapMany(toolExecutionResult -> {
if (toolExecutionResult.returnDirect()) {
// Return tool execution result directly to the client.
return Flux.just(ChatResponse.builder().from(response)
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
.build());
}
else {
// Send the tool execution result back to the model.
return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), response);
}
});
});
}
else {
return Flux.just(response);
}
})
.doOnError(observation::error)
.doFinally(s -> observation.stop())
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.util.List;

import com.fasterxml.jackson.databind.node.ObjectNode;
import reactor.core.publisher.Mono;

import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
Expand Down Expand Up @@ -96,4 +97,17 @@ public ToolExecutionResult executeToolCalls(Prompt prompt, ChatResponse chatResp
return this.delegateToolCallingManager.executeToolCalls(prompt, chatResponse);
}

/**
 * Asynchronous variant of tool call execution. The request is handed over, unchanged,
 * to the wrapped delegate tool calling manager and the delegate's {@link Mono} is
 * returned as-is.
 * @param prompt the original prompt that triggered the tool calls
 * @param chatResponse the chat response containing the tool calls to execute
 * @return the {@link Mono} of {@link ToolExecutionResult} produced by the delegate
 * @since 1.2.0
 */
@Override
public Mono<ToolExecutionResult> executeToolCallsAsync(Prompt prompt, ChatResponse chatResponse) {
	// No additional processing happens here; this is pure forwarding to the delegate.
	Mono<ToolExecutionResult> delegatedResult = this.delegateToolCallingManager
		.executeToolCallsAsync(prompt, chatResponse);
	return delegatedResult;
}

}
Loading