Skip to content

Commit 22d3b2a

Browse files
authored
Fix Qwen rendering and parsing for tool calls (#58)
1 parent 9aedece commit 22d3b2a

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

tinker_cookbook/renderers.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -340,13 +340,20 @@ def _render_message(self, idx: int, message: Message) -> tuple[list[int], list[i
340340
# <think> in the assistant messages, we so don't need to re-add it in those cases.
341341
ob_str += "<think>\n"
342342
# Observation (prompt) part
343-
ac_str = f"{ac_content}<|im_end|>"
343+
if "tool_calls" in message:
344+
ac_content += "\n".join(
345+
[
346+
f"<tool_call>\n{json.dumps(tool_call)}\n</tool_call>"
347+
for tool_call in message["tool_calls"]
348+
]
349+
)
350+
ac_content += "<|im_end|>"
344351
# Action part
345352
ac_tail_str = "" # No action tail needed for Qwen format
346353
# Action part that's only included in the last message in SFT
347354
return (
348355
self.tokenizer.encode(ob_str, add_special_tokens=False),
349-
self.tokenizer.encode(ac_str, add_special_tokens=False),
356+
self.tokenizer.encode(ac_content, add_special_tokens=False),
350357
self.tokenizer.encode(ac_tail_str, add_special_tokens=False),
351358
)
352359

@@ -409,11 +416,10 @@ def parse_response(self, response: list[int]) -> tuple[Message, bool]:
409416
if not parse_success:
410417
return assistant_message, False
411418

412-
# NOTE:
413-
# we use the <function_call>...</function_call> tag to wrap the tool call.
414-
match = re.search(
415-
r"<function_call>(.*?)</function_call>", assistant_message["content"], re.DOTALL
416-
)
419+
# Follow Qwen docs and Qwen-Agent's tool calling prompt to use <tool_call>...</tool_call> tags to wrap the tool call.
420+
# - https://qwen.readthedocs.io/en/latest/getting_started/concepts.html#tool-calling
421+
# - https://github.com/QwenLM/Qwen-Agent/blob/main/qwen_agent/llm/fncall_prompts/nous_fncall_prompt.py#L279-L282
422+
match = re.search(r"<tool_call>(.*?)</tool_call>", assistant_message["content"], re.DOTALL)
417423
if match:
418424
tool_calls = self._parse_tool_call(match.group(1))
419425
if tool_calls is None:

0 commit comments

Comments
 (0)