Skip to content

Commit 8e8654e

Browse files
q-nathangrandericbottard
authored andcommitted
GH-4596: Handle candidates containing both text and tool calls in VertexAiGeminiChatModel
Fix #4596 Auto-cherry-pick to 1.0.x Signed-off-by: NathanGrand <nathangrand@quantexa.com> Signed-off-by: Eric Bottard <eric.bottard@broadcom.com>
1 parent ffe11b4 commit 8e8654e

File tree

2 files changed

+59
-29
lines changed

2 files changed

+59
-29
lines changed

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.util.Collection;
2121
import java.util.List;
2222
import java.util.Map;
23+
import java.util.stream.Collectors;
2324

2425
import com.fasterxml.jackson.annotation.JsonInclude;
2526
import com.fasterxml.jackson.annotation.JsonInclude.Include;
@@ -600,39 +601,27 @@ protected List<Generation> responseCandidateToGeneration(Candidate candidate) {
600601
.finishReason(candidateFinishReason.name())
601602
.build();
602603

603-
boolean isFunctionCall = candidate.getContent().getPartsList().stream().allMatch(Part::hasFunctionCall);
604+
List<Part> parts = candidate.getContent().getPartsList();
604605

605-
if (isFunctionCall) {
606-
List<AssistantMessage.ToolCall> assistantToolCalls = candidate.getContent()
607-
.getPartsList()
608-
.stream()
609-
.filter(part -> part.hasFunctionCall())
610-
.map(part -> {
611-
FunctionCall functionCall = part.getFunctionCall();
612-
var functionName = functionCall.getName();
613-
String functionArguments = structToJson(functionCall.getArgs());
614-
return new AssistantMessage.ToolCall("", "function", functionName, functionArguments);
615-
})
616-
.toList();
606+
List<AssistantMessage.ToolCall> assistantToolCalls = parts.stream().filter(Part::hasFunctionCall).map(part -> {
607+
FunctionCall functionCall = part.getFunctionCall();
608+
var functionName = functionCall.getName();
609+
String functionArguments = structToJson(functionCall.getArgs());
610+
return new AssistantMessage.ToolCall("", "function", functionName, functionArguments);
611+
}).toList();
617612

618-
AssistantMessage assistantMessage = AssistantMessage.builder()
619-
.content("")
620-
.properties(messageMetadata)
621-
.toolCalls(assistantToolCalls)
622-
.build();
613+
String text = parts.stream()
614+
.filter(part -> part.hasText() && !part.getText().isEmpty())
615+
.map(Part::getText)
616+
.collect(Collectors.joining(" "));
623617

624-
return List.of(new Generation(assistantMessage, chatGenerationMetadata));
625-
}
626-
else {
627-
List<Generation> generations = candidate.getContent()
628-
.getPartsList()
629-
.stream()
630-
.map(part -> AssistantMessage.builder().content(part.getText()).properties(messageMetadata).build())
631-
.map(assistantMessage -> new Generation(assistantMessage, chatGenerationMetadata))
632-
.toList();
618+
AssistantMessage assistantMessage = AssistantMessage.builder()
619+
.content(text)
620+
.properties(messageMetadata)
621+
.toolCalls(assistantToolCalls)
622+
.build();
633623

634-
return generations;
635-
}
624+
return List.of(new Generation(assistantMessage, chatGenerationMetadata));
636625
}
637626

638627
private ChatResponseMetadata toChatResponseMetadata(Usage usage) {

models/spring-ai-vertex-ai-gemini/src/test/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModelIT.java

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,47 @@ void jsonTextToolCallingTest() {
414414
assertThat(response).contains("2025-05-08T10:10:10+02:00[Europe/Berlin]");
415415
}
416416

417+
/**
418+
* See https://github.com/spring-projects/spring-ai/pull/4599
419+
*/
420+
@Test
421+
void testMixedPartsMessages() {
422+
VertexAiGeminiChatModel chatModelWithTools = VertexAiGeminiChatModel.builder()
423+
.vertexAI(vertexAiApi())
424+
.defaultOptions(VertexAiGeminiChatOptions.builder().model("gemini-2.5-pro").temperature(0.0).build())
425+
.build();
426+
427+
ChatClient chatClient = ChatClient.builder(chatModelWithTools).build();
428+
429+
// Create a prompt that will encourage gemini to explain why it is calling tools
430+
// as it does.
431+
AlarmTools alarmTools = new AlarmTools();
432+
String response = chatClient.prompt()
433+
.tools(new CurrentTimeTools(), alarmTools)
434+
.system("You MUST include reasoning when you issue tool calls.")
435+
.user("Set an alarm for an hour from now, and tell me what time that was for.")
436+
.call()
437+
.content();
438+
439+
assertThat(response).isEqualTo("I have set an alarm for 11:10 AM.");
440+
assertThat(alarmTools.getAlarm()).isEqualTo("2025-05-08T11:10:10+02:00");
441+
}
442+
443+
public static class AlarmTools {
444+
445+
private String alarm;
446+
447+
@Tool(description = "Set a user alarm for the given time, provided in ISO-8601 format")
448+
void setAlarm(String time) {
449+
this.alarm = time;
450+
}
451+
452+
public String getAlarm() {
453+
return this.alarm;
454+
}
455+
456+
}
457+
417458
/**
418459
* Tool class that returns a JSON array to test the jsonToStruct method's ability to
419460
* handle JSON arrays. This specifically tests the PR changes that improve the

0 commit comments

Comments
 (0)