Skip to content

Commit 24f5fd7

Browse files
committed
[Bugfix] Fix Llama3JsonToolParser to handle nested JSON and strings with
braces The previous regex-based approach had two issues: 1. Could only handle one level of JSON nesting 2. Would fail on valid JSON if string values contained braces (e.g., code snippets) Example that would fail: {"name": "search", "parameters": {"query": "find users with status {active}"}} This fix replaces the regex with an iterative parser that: - Tracks brace counts and string boundaries correctly - Handles escape sequences (\", \\) - Supports arbitrary nesting depth - Works with multiple JSONs separated by semicolons Added comprehensive regression tests for nested JSON, strings with braces, code snippets, and escaped quotes. Signed-off-by: ym820 <yikai.mao@outlook.com>
1 parent fa6201e commit 24f5fd7

File tree

2 files changed

+130
-22
lines changed

2 files changed

+130
-22
lines changed

tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,3 +173,58 @@ def test_extract_tool_calls_very_deeply_nested_json(parser):
173173
import json
174174
args = json.loads(result.tool_calls[0].function.arguments)
175175
assert args["level1"]["level2"]["level3"]["value"] == "deep"
176+
177+
178+
def test_extract_tool_calls_with_braces_in_strings(parser):
179+
# Test with braces inside string values
180+
# This is a regression test for string-awareness in JSON extraction
181+
model_output = (
182+
'{"name": "search", '
183+
'"parameters": {"query": "find users with status {active}"}}'
184+
)
185+
result = parser.extract_tool_calls(model_output, None)
186+
187+
assert result.tools_called is True
188+
assert len(result.tool_calls) == 1
189+
assert result.tool_calls[0].function.name == "search"
190+
191+
# Verify the string with braces is captured correctly
192+
import json
193+
args = json.loads(result.tool_calls[0].function.arguments)
194+
assert args["query"] == "find users with status {active}"
195+
196+
197+
def test_extract_tool_calls_with_code_snippets(parser):
198+
# Test with code snippets containing braces
199+
model_output = (
200+
'{"name": "code_tool", '
201+
'"parameters": {"snippet": "function() { return {}; }"}}'
202+
)
203+
result = parser.extract_tool_calls(model_output, None)
204+
205+
assert result.tools_called is True
206+
assert len(result.tool_calls) == 1
207+
assert result.tool_calls[0].function.name == "code_tool"
208+
209+
# Verify the code snippet is captured correctly
210+
import json
211+
args = json.loads(result.tool_calls[0].function.arguments)
212+
assert args["snippet"] == "function() { return {}; }"
213+
214+
215+
def test_extract_tool_calls_with_escaped_quotes(parser):
216+
# Test with escaped quotes in strings
217+
model_output = (
218+
'{"name": "test", '
219+
'"parameters": {"text": "He said \\"hello {world}\\""}}'
220+
)
221+
result = parser.extract_tool_calls(model_output, None)
222+
223+
assert result.tools_called is True
224+
assert len(result.tool_calls) == 1
225+
assert result.tool_calls[0].function.name == "test"
226+
227+
# Verify escaped quotes are handled correctly
228+
import json
229+
args = json.loads(result.tool_calls[0].function.arguments)
230+
assert args["text"] == 'He said "hello {world}"'

vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py

Lines changed: 75 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from collections.abc import Sequence
66

77
import partial_json_parser
8-
import regex as re
98
from partial_json_parser.core.options import Allow
109
from transformers import PreTrainedTokenizerBase
1110

@@ -56,13 +55,6 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase):
5655
self.bot_token_id = tokenizer.encode(self.bot_token, add_special_tokens=False)[
5756
0
5857
]
59-
# Updated regex to match multiple JSONs separated by semicolons
60-
# This pattern uses recursion to handle arbitrarily nested JSON objects
61-
# (?R) is a recursive pattern that matches the entire pattern again
62-
self.tool_call_regex = re.compile(
63-
r"\{(?:[^{}]|(?R))*\}(?:\s*;\s*\{(?:[^{}]|(?R))*\})*",
64-
re.DOTALL,
65-
)
6658

6759
def extract_tool_calls(
6860
self, model_output: str, request: ChatCompletionRequest
@@ -78,23 +70,26 @@ def extract_tool_calls(
7870
tools_called=False, tool_calls=[], content=model_output
7971
)
8072

81-
# Find JSON object(s) in the text using regex
82-
match = self.tool_call_regex.search(model_output)
83-
if not match:
84-
return ExtractedToolCallInformation(
85-
tools_called=False, tool_calls=[], content=model_output
86-
)
87-
8873
try:
89-
json_str = match.group(0)
90-
# Split by semicolon and strip whitespace
91-
json_objects = [obj.strip() for obj in json_str.split(";")]
74+
# Find the start of JSON object
75+
start_idx = model_output.find("{")
76+
if start_idx == -1:
77+
return ExtractedToolCallInformation(
78+
tools_called=False, tool_calls=[], content=model_output
79+
)
80+
81+
# Use iterative parsing with brace counting to extract JSON objects
82+
# This handles strings with braces correctly by tracking string boundaries
83+
json_objects = self._extract_json_objects(model_output[start_idx:])
84+
85+
if not json_objects:
86+
return ExtractedToolCallInformation(
87+
tools_called=False, tool_calls=[], content=model_output
88+
)
9289

9390
tool_calls: list[ToolCall] = []
94-
for json_obj in json_objects:
95-
if not json_obj: # Skip empty strings
96-
continue
97-
obj = json.loads(json_obj)
91+
for json_str in json_objects:
92+
obj = json.loads(json_str)
9893
tool_calls.append(
9994
ToolCall(
10095
type="function",
@@ -122,6 +117,64 @@ def extract_tool_calls(
122117
tools_called=False, tool_calls=[], content=model_output
123118
)
124119

120+
def _extract_json_objects(self, text: str) -> list[str]:
121+
"""
122+
Extract JSON objects from text using brace counting with string awareness.
123+
Handles nested JSON and strings containing braces correctly.
124+
Supports multiple JSONs separated by semicolons.
125+
"""
126+
json_objects = []
127+
i = 0
128+
129+
while i < len(text):
130+
# Skip whitespace and semicolons
131+
while i < len(text) and text[i] in " \t\n\r;":
132+
i += 1
133+
134+
if i >= len(text) or text[i] != "{":
135+
break
136+
137+
# Track braces and string state
138+
brace_count = 0
139+
in_string = False
140+
escape_next = False
141+
start = i
142+
143+
while i < len(text):
144+
char = text[i]
145+
146+
if escape_next:
147+
escape_next = False
148+
i += 1
149+
continue
150+
151+
if char == "\\":
152+
escape_next = True
153+
i += 1
154+
continue
155+
156+
if char == '"' and not in_string:
157+
in_string = True
158+
elif char == '"' and in_string:
159+
in_string = False
160+
elif char == "{" and not in_string:
161+
brace_count += 1
162+
elif char == "}" and not in_string:
163+
brace_count -= 1
164+
if brace_count == 0:
165+
# Found complete JSON object
166+
json_objects.append(text[start : i + 1])
167+
i += 1
168+
break
169+
170+
i += 1
171+
172+
# If we didn't find a complete object, stop
173+
if brace_count != 0:
174+
break
175+
176+
return json_objects
177+
125178
def extract_tool_calls_streaming(
126179
self,
127180
previous_text: str,

0 commit comments

Comments
 (0)