Skip to content

Commit 3b96f85

Browse files
authored
[Chore]: Stream tokens vs characters in tool call parser tests (vllm-project#26513)
Signed-off-by: Ben Browning <bbrownin@redhat.com>
1 parent 23ad820 commit 3b96f85

File tree

6 files changed

+80
-41
lines changed

6 files changed

+80
-41
lines changed
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import pytest
5+
from transformers import AutoTokenizer
6+
7+
from vllm.transformers_utils.tokenizer import AnyTokenizer
8+
9+
10+
@pytest.fixture(scope="function")
def default_tokenizer() -> AnyTokenizer:
    """Return a fresh GPT-2 tokenizer for each test that requests it.

    gpt2 is used because it is a small, fast-to-load real tokenizer, letting
    streaming tests split text into genuine token deltas instead of mocks.
    """
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    return tokenizer

tests/entrypoints/openai/tool_parsers/test_llama3_json_tool_parser.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,15 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

44
import pytest
5-
from transformers import AutoTokenizer
65

76
from vllm.entrypoints.openai.protocol import ExtractedToolCallInformation
87
from vllm.entrypoints.openai.tool_parsers.llama_tool_parser import Llama3JsonToolParser
8+
from vllm.transformers_utils.tokenizer import AnyTokenizer
99

1010

1111
@pytest.fixture
def parser(default_tokenizer: AnyTokenizer):
    # Build a Llama3 JSON tool parser around the shared `default_tokenizer`
    # fixture (a real tokenizer rather than a mock) — presumably so streaming
    # tests can exercise genuine token-by-token deltas; see conftest.
    return Llama3JsonToolParser(default_tokenizer)
1614

1715

1816
def test_extract_tool_calls_simple(parser):

tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
)
1212
from vllm.entrypoints.openai.protocol import FunctionCall
1313
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
14+
from vllm.transformers_utils.tokenizer import AnyTokenizer
1415

1516
# Test cases similar to pythonic parser but with Llama4 specific format
1617
SIMPLE_FUNCTION_OUTPUT = "[get_weather(city='LA', metric='C')]"
@@ -63,10 +64,9 @@
6364

6465

6566
@pytest.mark.parametrize("streaming", [True, False])
66-
def test_no_tool_call(streaming: bool):
67-
mock_tokenizer = MagicMock()
67+
def test_no_tool_call(streaming: bool, default_tokenizer: AnyTokenizer):
6868
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
69-
mock_tokenizer
69+
default_tokenizer
7070
)
7171
model_output = "How can I help you today?"
7272

@@ -205,11 +205,13 @@ def test_no_tool_call(streaming: bool):
205205

206206
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES)
207207
def test_tool_call(
208-
streaming: bool, model_output: str, expected_tool_calls: list[FunctionCall]
208+
streaming: bool,
209+
model_output: str,
210+
expected_tool_calls: list[FunctionCall],
211+
default_tokenizer: AnyTokenizer,
209212
):
210-
mock_tokenizer = MagicMock()
211213
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
212-
mock_tokenizer
214+
default_tokenizer
213215
)
214216

215217
content, tool_calls = run_tool_extraction(
@@ -222,10 +224,9 @@ def test_tool_call(
222224
assert actual.function == expected
223225

224226

225-
def test_streaming_tool_call_with_large_steps():
226-
mock_tokenizer = MagicMock()
227+
def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
227228
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
228-
mock_tokenizer
229+
default_tokenizer
229230
)
230231
model_output_deltas = [
231232
"<|python_start|>[get_weather(city='LA', metric='C'), "
@@ -245,11 +246,10 @@ def test_streaming_tool_call_with_large_steps():
245246

246247

247248
@pytest.mark.parametrize("streaming", [False])
248-
def test_regex_timeout_handling(streaming: bool):
249+
def test_regex_timeout_handling(streaming: bool, default_tokenizer: AnyTokenizer):
249250
"""test regex timeout is handled gracefully"""
250-
mock_tokenizer = MagicMock()
251251
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
252-
mock_tokenizer
252+
default_tokenizer
253253
)
254254

255255
fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2

tests/entrypoints/openai/tool_parsers/test_olmo3_tool_parser.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
)
1212
from vllm.entrypoints.openai.protocol import FunctionCall
1313
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
14+
from vllm.transformers_utils.tokenizer import AnyTokenizer
1415

1516
# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
1617
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
@@ -68,9 +69,10 @@
6869

6970

7071
@pytest.mark.parametrize("streaming", [True, False])
71-
def test_no_tool_call(streaming: bool):
72-
mock_tokenizer = MagicMock()
73-
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer)
72+
def test_no_tool_call(streaming: bool, default_tokenizer: AnyTokenizer):
73+
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
74+
default_tokenizer
75+
)
7476
model_output = "How can I help you today?"
7577

7678
content, tool_calls = run_tool_extraction(
@@ -183,10 +185,14 @@ def test_no_tool_call(streaming: bool):
183185

184186
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES)
185187
def test_tool_call(
186-
streaming: bool, model_output: str, expected_tool_calls: list[FunctionCall]
188+
streaming: bool,
189+
model_output: str,
190+
expected_tool_calls: list[FunctionCall],
191+
default_tokenizer: AnyTokenizer,
187192
):
188-
mock_tokenizer = MagicMock()
189-
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer)
193+
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
194+
default_tokenizer
195+
)
190196

191197
content, tool_calls = run_tool_extraction(
192198
tool_parser, model_output, streaming=streaming
@@ -199,9 +205,10 @@ def test_tool_call(
199205
assert actual.function == expected
200206

201207

202-
def test_streaming_tool_call_with_large_steps():
203-
mock_tokenizer = MagicMock()
204-
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer)
208+
def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
209+
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
210+
default_tokenizer
211+
)
205212
model_output_deltas = [
206213
"<function_calls>get_weather(city='San",
207214
" Francisco', metric='celsius')\n"
@@ -221,10 +228,11 @@ def test_streaming_tool_call_with_large_steps():
221228

222229

223230
@pytest.mark.parametrize("streaming", [False])
224-
def test_regex_timeout_handling(streaming: bool):
231+
def test_regex_timeout_handling(streaming: bool, default_tokenizer: AnyTokenizer):
225232
"""test regex timeout is handled gracefully"""
226-
mock_tokenizer = MagicMock()
227-
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(mock_tokenizer)
233+
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
234+
default_tokenizer
235+
)
228236

229237
fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2
230238

tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
)
1212
from vllm.entrypoints.openai.protocol import FunctionCall
1313
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
14+
from vllm.transformers_utils.tokenizer import AnyTokenizer
1415

1516
# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
1617
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
@@ -60,10 +61,9 @@
6061

6162

6263
@pytest.mark.parametrize("streaming", [True, False])
63-
def test_no_tool_call(streaming: bool):
64-
mock_tokenizer = MagicMock()
64+
def test_no_tool_call(streaming: bool, default_tokenizer: AnyTokenizer):
6565
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
66-
mock_tokenizer
66+
default_tokenizer
6767
)
6868
model_output = "How can I help you today?"
6969

@@ -165,11 +165,13 @@ def test_no_tool_call(streaming: bool):
165165

166166
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES)
167167
def test_tool_call(
168-
streaming: bool, model_output: str, expected_tool_calls: list[FunctionCall]
168+
streaming: bool,
169+
model_output: str,
170+
expected_tool_calls: list[FunctionCall],
171+
default_tokenizer: AnyTokenizer,
169172
):
170-
mock_tokenizer = MagicMock()
171173
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
172-
mock_tokenizer
174+
default_tokenizer
173175
)
174176

175177
content, tool_calls = run_tool_extraction(
@@ -183,10 +185,9 @@ def test_tool_call(
183185
assert actual.function == expected
184186

185187

186-
def test_streaming_tool_call_with_large_steps():
187-
mock_tokenizer = MagicMock()
188+
def test_streaming_tool_call_with_large_steps(default_tokenizer: AnyTokenizer):
188189
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
189-
mock_tokenizer
190+
default_tokenizer
190191
)
191192
model_output_deltas = [
192193
"[get_weather(city='San",
@@ -207,11 +208,10 @@ def test_streaming_tool_call_with_large_steps():
207208

208209

209210
@pytest.mark.parametrize("streaming", [False])
210-
def test_regex_timeout_handling(streaming: bool):
211+
def test_regex_timeout_handling(streaming: bool, default_tokenizer: AnyTokenizer):
211212
"""test regex timeout is handled gracefully"""
212-
mock_tokenizer = MagicMock()
213213
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
214-
mock_tokenizer
214+
default_tokenizer
215215
)
216216

217217
fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2

tests/entrypoints/openai/tool_parsers/utils.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
ToolCall,
1212
)
1313
from vllm.entrypoints.openai.tool_parsers import ToolParser
14+
from vllm.transformers_utils.tokenizer import AnyTokenizer
1415

1516

1617
class StreamingToolReconstructor:
@@ -110,12 +111,32 @@ def run_tool_extraction_nonstreaming(
110111
return tool_parser.extract_tool_calls(model_output, request)
111112

112113

114+
def split_string_into_token_deltas(tokenizer: AnyTokenizer, text: str) -> list[str]:
    """Split *text* into per-token string deltas using *tokenizer*.

    Each returned element is the text newly contributed by one additional
    token. Rather than decoding tokens individually, successively longer
    token-id prefixes are decoded and only the freshly appeared suffix is
    emitted each time — this keeps deltas correct when adjacent tokens'
    bytes merge during decoding (e.g. multi-byte characters).
    """
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    deltas: list[str] = []
    decoded_so_far = ""
    for prefix_len in range(1, len(token_ids) + 1):
        decoded = tokenizer.decode(token_ids[:prefix_len])
        deltas.append(decoded[len(decoded_so_far):])
        decoded_so_far = decoded
    return deltas
127+
128+
113129
def run_tool_extraction_streaming(
114130
tool_parser: ToolParser,
115131
model_deltas: Iterable[str],
116132
request: ChatCompletionRequest | None = None,
117133
assert_one_tool_per_delta: bool = True,
118134
) -> StreamingToolReconstructor:
135+
if isinstance(model_deltas, str):
136+
model_deltas = split_string_into_token_deltas(
137+
tool_parser.model_tokenizer, model_deltas
138+
)
139+
119140
request = request or ChatCompletionRequest(messages=[], model="test-model")
120141
reconstructor = StreamingToolReconstructor(
121142
assert_one_tool_per_delta=assert_one_tool_per_delta

0 commit comments

Comments
 (0)