Skip to content

Commit 72b2505

Browse files
2ez4bz, Wanli-Jiang
authored and committed
[None][fixes] Add tool call parsing fixes and Qwen3 coder parser
Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
1 parent 972c21c commit 72b2505

File tree

7 files changed

+960
-70
lines changed

7 files changed

+960
-70
lines changed

tensorrt_llm/serve/chat_utils.py

Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
import json
12
import uuid
23
from functools import partial
34
from typing import (Any, Callable, Coroutine, Dict, Iterable, List, Literal,
@@ -185,6 +186,36 @@ def parse_chat_message_content(
185186
content,
186187
mm_data_tracker,
187188
)
189+
if role == "assistant":
190+
result.update(_parse_assistant_message_content(message))
191+
elif role == "tool":
192+
result.update(_parse_tool_message_content(message))
193+
return result
194+
195+
196+
# Adapted from: https://github.com/vllm-project/vllm/blob/4574d48bab9c4e38b7c0a830eeefc8f0980e8c58/vllm/entrypoints/chat_utils.py#L1406
197+
def _parse_assistant_message_content(message: Dict[str, Any]) -> Dict[str, Any]:
198+
result = {}
199+
tool_calls = message.get("tool_calls")
200+
if tool_calls is not None:
201+
result["tool_calls"] = []
202+
for item in tool_calls:
203+
if content := item["function"].get("arguments"):
204+
if isinstance(content, str):
205+
item["function"]["arguments"] = json.loads(content)
206+
else:
207+
item["function"]["arguments"] = content
208+
else:
209+
item["function"]["arguments"] = {}
210+
result["tool_calls"].append(item)
211+
212+
return result
213+
214+
215+
def _parse_tool_message_content(message: Dict[str, Any]) -> Dict[str, Any]:
216+
result = {}
217+
if "tool_call_id" in message:
218+
result["tool_call_id"] = message["tool_call_id"]
188219
return result
189220

190221

tensorrt_llm/serve/openai_protocol.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,12 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False):
396396

397397
class CustomChatCompletionMessageParam(TypedDict, total=False):
398398
"""Enables custom roles in the Chat Completion API."""
399+
400+
# This is so custom fields not in any of the `ChatCompletionMessage<XYZ>Param` defined by OpenAI
401+
# are still allowed.
402+
# Examples include: assistant messages with `reasoning` / `reasoning_content`.
403+
__pydantic_config__ = ConfigDict(extra="allow") # type: ignore
404+
399405
role: Required[str]
400406
"""The role of the message's author."""
401407

Lines changed: 344 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,344 @@
1+
# Adapted from: https://raw.githubusercontent.com/sgl-project/sglang/d8fcbaa38da95201914a1277971044ee66837b26/python/sglang/srt/function_call/qwen3_coder_detector.py
2+
3+
import ast
4+
import html
5+
import json
6+
import re
7+
from typing import Any, Dict, List, Tuple
8+
9+
from tensorrt_llm.logger import logger
10+
from tensorrt_llm.serve.openai_protocol import ChatCompletionToolsParam as Tool
11+
from tensorrt_llm.serve.tool_parser.base_tool_parser import BaseToolParser
12+
from tensorrt_llm.serve.tool_parser.core_types import (
13+
StreamingParseResult,
14+
ToolCallItem,
15+
_GetInfoFunc,
16+
)
17+
18+
19+
def _safe_val(raw: str) -> Any:
20+
raw = html.unescape(raw.strip())
21+
try:
22+
return json.loads(raw)
23+
except Exception:
24+
try:
25+
return ast.literal_eval(raw)
26+
except Exception:
27+
return raw
28+
29+
30+
class Qwen3CoderToolParser(BaseToolParser):
    """Tool parser for the XML-style tool-call format of Qwen 3 (coder) models.

    Assumes function call format:
        <tool_call>
        <function=execute_bash>
        <parameter=command>
        pwd && ls
        </parameter>
        </function>
        </tool_call>

    ``detect_and_parse`` handles a complete piece of text in one shot, while
    ``parse_streaming_increment`` consumes text chunk by chunk and emits the
    function name first, followed by JSON fragments of the arguments.

    NOTE(review): ``current_tool_id``, ``prev_tool_call_arr``,
    ``streamed_args_for_tool``, ``current_tool_name_sent``,
    ``_get_tool_indices`` and ``parse_base_json`` are not defined in this class
    and are presumably provided by ``BaseToolParser`` - confirm against the
    base class.
    """

    # Patterns used by the streaming path, compiled once rather than looked up
    # on every chunk.
    _STREAM_FUNCTION_RE = re.compile(r"<function=([^>]+)>")
    _STREAM_PARAMETER_RE = re.compile(r"<parameter=([^>]+)>(.*?)</parameter>", re.DOTALL)

    def __init__(self):
        """Set up the format tokens, the parsing regexes and the streaming state."""
        super().__init__()
        self.tool_call_start_token: str = "<tool_call>"
        self.tool_call_end_token: str = "</tool_call>"
        self.tool_call_prefix: str = "<function="
        # Each pattern has a second alternative that also matches a block
        # whose closing tag is missing (it then runs to the end of the text).
        self.tool_call_regex = re.compile(
            r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", re.DOTALL
        )
        self.tool_call_function_regex = re.compile(
            r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL
        )
        self.tool_call_parameter_regex = re.compile(
            r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL
        )
        # Text received by the streaming path that has not been consumed yet.
        self._buf: str = ""

        # Streaming state variables
        self._current_function_name: str = ""
        self._current_parameters: Dict[str, Any] = {}
        # Track what parameter content we've streamed (only ever reset in this
        # class, never read).
        self._streamed_parameters: Dict[str, str] = {}
        self._in_tool_call: bool = False
        self._function_name_sent: bool = False

    def has_tool_call(self, text: str) -> bool:
        """Return True when ``text`` contains the ``<tool_call>`` start token."""
        return self.tool_call_start_token in text

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """Parse a complete piece of text in one shot.

        Args:
            text: The full text to scan.
            tools: The tools the calls are validated against (forwarded to
                ``parse_base_json``).

        Returns:
            A ``StreamingParseResult`` holding the text outside of complete
            tool-call blocks and the parsed calls.
        """
        normal, calls = self._extract(text, tools)
        return StreamingParseResult(normal_text=normal, calls=calls)

    def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> StreamingParseResult:
        """Consume one chunk of streamed text and emit what can be emitted so far.

        For every tool call, the function name is emitted first (with empty
        parameters), then the arguments are emitted as JSON fragments each
        time a ``<parameter>`` block completes, and finally a closing ``}`` is
        emitted once ``</tool_call>`` is seen.

        Args:
            new_text: The newly received text; it is appended to the internal
                buffer.
            tools: The available tools. They are only read on the first call,
                where the name lookup is built and cached on the instance.

        Returns:
            A ``StreamingParseResult`` with the normal text and the tool-call
            items produced by this chunk.
        """
        self._buf += new_text
        normal = ""
        calls: List[ToolCallItem] = []

        # Build tool indices for validation
        if not hasattr(self, "_tool_indices"):
            self._tool_indices = self._get_tool_indices(tools)

        start_token = self.tool_call_start_token
        end_token = self.tool_call_end_token

        while True:
            # Outside of a tool call: everything up to the next start token is
            # normal text.
            if not self._in_tool_call:
                start_pos = self._buf.find(start_token)
                if start_pos == -1:
                    normal += self._buf
                    self._buf = ""
                    break

                normal += self._buf[:start_pos]
                # Drop the start token itself along with the preceding text.
                self._buf = self._buf[start_pos + len(start_token) :]

                self._in_tool_call = True
                self._function_name_sent = False
                self._current_function_name = ""
                self._current_parameters = {}
                self._streamed_parameters = {}
                continue

            # We're in a tool call, try to parse function name if not sent yet
            if not self._function_name_sent:
                # Look for function name pattern: <function=name>
                function_match = self._STREAM_FUNCTION_RE.search(self._buf)
                if function_match is None:
                    # Function name not complete yet, wait for more text
                    break

                function_name = function_match.group(1).strip()

                # Validate function name
                if function_name not in self._tool_indices:
                    # Invalid function name: reset state and hand the rest of
                    # the buffer back as normal text.
                    logger.warning(f"Invalid function name: {function_name}")
                    self._reset_streaming_state()
                    normal += self._buf
                    self._buf = ""
                    break

                self._current_function_name = function_name
                self._function_name_sent = True

                # Initialize tool call tracking
                if self.current_tool_id == -1:
                    self.current_tool_id = 0

                # Ensure tracking arrays are large enough
                while len(self.prev_tool_call_arr) <= self.current_tool_id:
                    self.prev_tool_call_arr.append({})
                while len(self.streamed_args_for_tool) <= self.current_tool_id:
                    self.streamed_args_for_tool.append("")

                # Store tool call info
                self.prev_tool_call_arr[self.current_tool_id] = {
                    "name": function_name,
                    "arguments": {},
                }

                # Send tool name with empty parameters
                calls.append(
                    ToolCallItem(
                        tool_index=self.current_tool_id,
                        name=function_name,
                        parameters="",
                    )
                )

                # Remove the processed function declaration
                self._buf = self._buf[function_match.end() :]
                continue

            # The function name has been sent: stream parameters. Only look at
            # the text of the *current* call; anything after its end token may
            # belong to a following tool call and must not leak into this one.
            end_pos = self._buf.find(end_token)
            current_call_text = self._buf if end_pos == -1 else self._buf[:end_pos]
            calls.extend(self._parse_and_stream_parameters(current_call_text))

            if end_pos == -1:
                # Tool call not complete yet, wait for more text
                break

            # The streamed fragments open a JSON object but never close it, so
            # a closing brace is owed whenever anything was streamed. Counting
            # braces is not reliable because string values may contain them.
            current_streamed = self.streamed_args_for_tool[self.current_tool_id]
            if current_streamed:
                calls.append(
                    ToolCallItem(
                        tool_index=self.current_tool_id,
                        name=None,
                        parameters="}",
                    )
                )
                self.streamed_args_for_tool[self.current_tool_id] = current_streamed + "}"

            # Complete the tool call
            self._buf = self._buf[end_pos + len(end_token) :]
            self._reset_streaming_state()
            self.current_tool_id += 1

        return StreamingParseResult(normal_text=normal, calls=calls)

    def _parse_and_stream_parameters(self, text_to_parse: str) -> List[ToolCallItem]:
        """Parse complete parameter blocks from text and return any tool call items to emit.

        This method:
        1. Finds all complete <parameter> blocks
        2. Parses them into a dictionary
        3. Compares with current parameters and generates diff if needed
        4. Updates internal state

        The emitted fragments build a JSON object incrementally: the first
        fragment opens it with ``{`` and later ones start with ``, ``. The
        closing ``}`` is not emitted here.

        Args:
            text_to_parse: The text to search for parameter blocks

        Returns:
            List of ToolCallItem objects to emit (may be empty)
        """
        calls: List[ToolCallItem] = []

        # Build new parameters dictionary from all complete parameter blocks
        new_params: Dict[str, Any] = {
            match.group(1).strip(): _safe_val(match.group(2))
            for match in self._STREAM_PARAMETER_RE.finditer(text_to_parse)
        }

        if new_params == self._current_parameters:
            return calls

        previous_args_json = self.streamed_args_for_tool[self.current_tool_id]

        def _render(key: str) -> str:
            # One `"key": value` member of the arguments object.
            return (
                f"{json.dumps(key, ensure_ascii=False)}: "
                f"{json.dumps(new_params[key], ensure_ascii=False)}"
            )

        if not self._current_parameters:
            # First parameter(s) - start JSON object but don't close it yet
            json_fragment = "{" + ", ".join(_render(key) for key in new_params)
            calls.append(
                ToolCallItem(
                    tool_index=self.current_tool_id,
                    name=None,
                    parameters=json_fragment,
                )
            )
            self.streamed_args_for_tool[self.current_tool_id] = json_fragment
        else:
            # Additional parameters - add them incrementally, in the order in
            # which they appear in the text (a set difference would make the
            # order depend on string hashing).
            new_keys = [key for key in new_params if key not in self._current_parameters]
            if new_keys:
                # Build the continuation part (no closing brace yet)
                json_fragment = ", " + ", ".join(_render(key) for key in new_keys)
                calls.append(
                    ToolCallItem(
                        tool_index=self.current_tool_id,
                        name=None,
                        parameters=json_fragment,
                    )
                )
                self.streamed_args_for_tool[self.current_tool_id] = (
                    previous_args_json + json_fragment
                )

        # Update current state
        self._current_parameters = new_params
        self.prev_tool_call_arr[self.current_tool_id]["arguments"] = new_params

        return calls

    def _reset_streaming_state(self):
        """Reset streaming state for the next tool call."""
        self._in_tool_call = False
        self._function_name_sent = False
        self._current_function_name = ""
        self._current_parameters = {}
        self._streamed_parameters = {}
        self.current_tool_name_sent = False

    def _extract(self, text: str, tools: List[Tool]) -> Tuple[str, List[ToolCallItem]]:
        """Split ``text`` into normal text and parsed tool calls.

        Only blocks that have both a start and an end token are parsed; a
        start token without a matching end token is returned as normal text
        together with everything that follows it.

        Args:
            text: The full text to scan.
            tools: The tools forwarded to ``_parse_block``.

        Returns:
            A tuple of the concatenated normal text and the list of calls.
        """
        normal_parts: List[str] = []
        calls: List[ToolCallItem] = []
        end_len = len(self.tool_call_end_token)
        cursor = 0
        while True:
            start = text.find(self.tool_call_start_token, cursor)
            if start == -1:
                normal_parts.append(text[cursor:])
                break
            normal_parts.append(text[cursor:start])
            end = text.find(self.tool_call_end_token, start)
            if end == -1:
                # Unterminated block: keep it verbatim as normal text.
                normal_parts.append(text[start:])
                break
            cursor = end + end_len
            calls.extend(self._parse_block(text[start:cursor], tools))
        return "".join(normal_parts), calls

    def _parse_block(self, block: str, tools: List[Tool]) -> List[ToolCallItem]:
        """Parse every ``<function=...>`` found in one tool-call block.

        Args:
            block: The text of one block, including its start and end tokens.
            tools: The tools forwarded to ``parse_base_json``.

        Returns:
            The tool-call items produced by ``parse_base_json``. Functions for
            which it raises are logged and dropped.
        """
        res: List[ToolCallItem] = []
        for closed, unclosed in self.tool_call_function_regex.findall(block):
            # Exactly one alternative of the pattern matched; use that group.
            txt = closed or unclosed
            if ">" not in txt:
                continue
            fname, body = txt.split(">", 1)
            fname = fname.strip()

            params: Dict[str, Any] = {}
            for p_closed, p_unclosed in self.tool_call_parameter_regex.findall(body):
                ptxt = p_closed or p_unclosed
                if ">" not in ptxt:
                    continue
                pname, pval = ptxt.split(">", 1)
                params[pname.strip()] = _safe_val(pval.lstrip("\n").rstrip("\n"))

            raw = {"name": fname, "arguments": params}
            try:
                # TODO: fix idx in function call, the index for a function
                # call will always be -1 in parse_base_json
                res.extend(self.parse_base_json(raw, tools))
            except Exception:
                # Best effort: one bad call must not discard the others.
                logger.warning(f"invalid tool call for {fname} dropped")
        return res

    def supports_structural_tag(self) -> bool:
        """Return False: this parser reports no structural-tag support."""
        return False

    def structure_info(self) -> _GetInfoFunc:
        """Not implemented for this parser.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

0 commit comments

Comments
 (0)