Skip to content

Commit e5be2fa

Browse files
committed
Fix streaming tool call merge in MessageAggregator
- Added mergeToolCalls() method to aggregate tool call fragments by ID
- Preserves name from first chunk and concatenates arguments
- Added MessageAggregatorQwenTest to reproduce Qwen streaming pattern

Signed-off-by: kimtaewoong <ktw2172@gmail.com>
1 parent bd1834d commit e5be2fa

File tree

2 files changed

+137
-1
lines changed

2 files changed

+137
-1
lines changed

spring-ai-model/src/main/java/org/springframework/ai/chat/model/MessageAggregator.java

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.util.ArrayList;
2020
import java.util.HashMap;
21+
import java.util.LinkedHashMap;
2122
import java.util.List;
2223
import java.util.Map;
2324
import java.util.concurrent.atomic.AtomicReference;
@@ -104,7 +105,7 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
104105
}
105106
AssistantMessage outputMessage = chatResponse.getResult().getOutput();
106107
if (!CollectionUtils.isEmpty(outputMessage.getToolCalls())) {
107-
toolCallsRef.get().addAll(outputMessage.getToolCalls());
108+
mergeToolCalls(toolCallsRef.get(), outputMessage.getToolCalls());
108109
}
109110

110111
}
@@ -188,6 +189,51 @@ public Flux<ChatResponse> aggregate(Flux<ChatResponse> fluxChatResponse,
188189
}).doOnError(e -> logger.error("Aggregation Error", e));
189190
}
190191

192+
/**
193+
* Merge tool calls by id to handle streaming responses where tool call data is split
194+
* across multiple chunks. This is common in OpenAI-compatible APIs like Qwen, where
195+
* the first chunk contains the function name and subsequent chunks contain only
196+
* arguments.
197+
* @param existingToolCalls the list of existing tool calls to merge into
198+
* @param newToolCalls the new tool calls to merge
199+
*/
200+
private void mergeToolCalls(List<ToolCall> existingToolCalls, List<ToolCall> newToolCalls) {
201+
Map<String, ToolCall> toolCallMap = new LinkedHashMap<>();
202+
203+
// Build map from existing tool calls
204+
for (ToolCall existing : existingToolCalls) {
205+
toolCallMap.put(existing.id(), existing);
206+
}
207+
208+
// Merge new tool calls
209+
for (ToolCall newCall : newToolCalls) {
210+
if (toolCallMap.containsKey(newCall.id())) {
211+
// Merge with existing tool call
212+
ToolCall existing = toolCallMap.get(newCall.id());
213+
214+
// Use non-empty name, prefer new if both present
215+
String mergedName = StringUtils.hasText(newCall.name()) ? newCall.name() : existing.name();
216+
217+
// Use non-empty type, prefer new if both present
218+
String mergedType = StringUtils.hasText(newCall.type()) ? newCall.type() : existing.type();
219+
220+
// Concatenate arguments
221+
String mergedArgs = (existing.arguments() != null ? existing.arguments() : "")
222+
+ (newCall.arguments() != null ? newCall.arguments() : "");
223+
224+
toolCallMap.put(newCall.id(), new ToolCall(newCall.id(), mergedType, mergedName, mergedArgs));
225+
}
226+
else {
227+
// New tool call
228+
toolCallMap.put(newCall.id(), newCall);
229+
}
230+
}
231+
232+
// Update the existing list
233+
existingToolCalls.clear();
234+
existingToolCalls.addAll(toolCallMap.values());
235+
}
236+
191237
public record DefaultUsage(Integer promptTokens, Integer completionTokens, Integer totalTokens) implements Usage {
192238

193239
@Override
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.chat.model;
18+
19+
import java.util.List;
20+
import java.util.concurrent.atomic.AtomicReference;
21+
22+
import org.junit.jupiter.api.Test;
23+
import reactor.core.publisher.Flux;
24+
25+
import org.springframework.ai.chat.messages.AssistantMessage;
26+
27+
import static org.assertj.core.api.Assertions.assertThat;
28+
29+
/**
30+
* Tests for {@link MessageAggregator} with Qwen streaming tool call pattern.
31+
* @author Taewoong Kim
32+
*/
33+
class MessageAggregatorQwenTest {
34+
35+
private final MessageAggregator messageAggregator = new MessageAggregator();
36+
37+
/**
38+
* Test based on actual Qwen streaming response from OpenRouter. Qwen sends tool name
39+
* only in the first chunk, subsequent chunks have empty name.
40+
*/
41+
@Test
42+
void shouldMergeToolCallsFromQwenStreamingResponse() {
43+
// Given: Qwen streaming chunks (actual pattern from curl test)
44+
// Chunk 1: {"tool_calls":[{"id":"...","function":{"name":"getCurrentWeather"}}]}
45+
ChatResponse chunk1 = new ChatResponse(List.of(new Generation(AssistantMessage.builder()
46+
.toolCalls(List.of(new AssistantMessage.ToolCall("chatcmpl-tool-123", "function", "getCurrentWeather", "")))
47+
.build())));
48+
49+
// Chunk 2: {"tool_calls":[{"index":0,"function":{"arguments":"{\"location\":
50+
// \""}}]}
51+
ChatResponse chunk2 = new ChatResponse(List.of(new Generation(AssistantMessage.builder()
52+
.toolCalls(List.of(new AssistantMessage.ToolCall("chatcmpl-tool-123", "function", "", "{\"location\": \"")))
53+
.build())));
54+
55+
// Chunk 3: {"tool_calls":[{"index":0,"function":{"arguments":"Se"}}]}
56+
ChatResponse chunk3 = new ChatResponse(List.of(new Generation(AssistantMessage.builder()
57+
.toolCalls(List.of(new AssistantMessage.ToolCall("chatcmpl-tool-123", "function", "", "Se")))
58+
.build())));
59+
60+
// Chunk 4: {"tool_calls":[{"index":0,"function":{"arguments":"oul"}}]}
61+
ChatResponse chunk4 = new ChatResponse(List.of(new Generation(AssistantMessage.builder()
62+
.toolCalls(List.of(new AssistantMessage.ToolCall("chatcmpl-tool-123", "function", "", "oul")))
63+
.build())));
64+
65+
// Chunk 5: {"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]}
66+
ChatResponse chunk5 = new ChatResponse(List.of(new Generation(AssistantMessage.builder()
67+
.toolCalls(List.of(new AssistantMessage.ToolCall("chatcmpl-tool-123", "function", "", "\"}")))
68+
.build())));
69+
70+
Flux<ChatResponse> flux = Flux.just(chunk1, chunk2, chunk3, chunk4, chunk5);
71+
72+
// When: Aggregate the streaming responses
73+
AtomicReference<ChatResponse> finalResponse = new AtomicReference<>();
74+
this.messageAggregator.aggregate(flux, finalResponse::set).blockLast();
75+
76+
// Then: Verify the tool call was properly merged
77+
assertThat(finalResponse.get()).isNotNull();
78+
List<AssistantMessage.ToolCall> toolCalls = finalResponse.get().getResult().getOutput().getToolCalls();
79+
80+
assertThat(toolCalls).hasSize(1);
81+
AssistantMessage.ToolCall mergedToolCall = toolCalls.get(0);
82+
83+
// The bug was: toolName would be empty string, causing "toolName cannot be null
84+
// or empty"
85+
// After fix: name should be preserved from first chunk
86+
assertThat(mergedToolCall.name()).isEqualTo("getCurrentWeather");
87+
assertThat(mergedToolCall.arguments()).isEqualTo("{\"location\": \"Seoul\"}");
88+
}
89+
90+
}

0 commit comments

Comments
 (0)