Skip to content

Commit d9721b9

Browse files
committed
Add chat template.
1 parent 9dbbe5f commit d9721b9

File tree

4 files changed

+251
-0
lines changed

4 files changed

+251
-0
lines changed

ads/llm/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
ChatOCIModelDeploymentVLLM,
1616
ChatOCIModelDeploymentTGI,
1717
)
18+
from ads.llm.chat_template import ChatTemplates
1819
except ImportError as ex:
1920
if ex.name == "langchain":
2021
raise ImportError(

ads/llm/chat_template.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*--
3+
4+
# Copyright (c) 2023 Oracle and/or its affiliates.
5+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6+
7+
8+
import os
9+
10+
11+
class ChatTemplates:
12+
"""Contains chat templates."""
13+
14+
@staticmethod
15+
def _read_template(filename):
16+
with open(
17+
os.path.join(os.path.dirname(__file__), "templates", filename),
18+
mode="r",
19+
encoding="utf-8",
20+
) as f:
21+
return f.read()
22+
23+
@staticmethod
24+
def mistral():
25+
"""Chat template for auto tool calling with Mistral model deploy with vLLM."""
26+
return ChatTemplates._read_template("tool_chat_template_mistral_parallel.jinja")
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
{%- macro json_to_python_type(json_spec) %}
2+
{%- set basic_type_map = {
3+
"string": "str",
4+
"number": "float",
5+
"integer": "int",
6+
"boolean": "bool"
7+
} %}
8+
9+
{%- if basic_type_map[json_spec.type] is defined %}
10+
{{- basic_type_map[json_spec.type] }}
11+
{%- elif json_spec.type == "array" %}
12+
{{- "list[" + json_to_python_type(json_spec|items) + "]" }}
13+
{%- elif json_spec.type == "object" %}
14+
{%- if json_spec.additionalProperties is defined %}
15+
{{- "dict[str, " + json_to_python_type(json_spec.additionalProperties) + ']' }}
16+
{%- else %}
17+
{{- "dict" }}
18+
{%- endif %}
19+
{%- elif json_spec.type is iterable %}
20+
{{- "Union[" }}
21+
{%- for t in json_spec.type %}
22+
{{- json_to_python_type({"type": t}) }}
23+
{%- if not loop.last %}
24+
{{- "," }}
25+
{%- endif %}
26+
{%- endfor %}
27+
{{- "]" }}
28+
{%- else %}
29+
{{- "Any" }}
30+
{%- endif %}
31+
{%- endmacro %}
32+
33+
34+
{{- bos_token }}
35+
{{- "<|im_start|>system\nYou are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: <tools> " }}
36+
{%- if tools is iterable and tools | length > 0 %}
37+
{%- for tool in tools %}
38+
{%- if tool.function is defined %}
39+
{%- set tool = tool.function %}
40+
{%- endif %}
41+
{{- '{"type": "function", "function": ' }}
42+
{{- '{"name": "' + tool.name + '", ' }}
43+
{{- '"description": "' + tool.name + '(' }}
44+
{%- for param_name, param_fields in tool.parameters.properties|items %}
45+
{{- param_name + ": " + json_to_python_type(param_fields) }}
46+
{%- if not loop.last %}
47+
{{- ", " }}
48+
{%- endif %}
49+
{%- endfor %}
50+
{{- ")" }}
51+
{%- if tool.return is defined %}
52+
{{- " -> " + json_to_python_type(tool.return) }}
53+
{%- endif %}
54+
{{- " - " + tool.description + "\n\n" }}
55+
{%- for param_name, param_fields in tool.parameters.properties|items %}
56+
{%- if loop.first %}
57+
{{- " Args:\n" }}
58+
{%- endif %}
59+
{{- " " + param_name + "(" + json_to_python_type(param_fields) + "): " + param_fields.description|trim }}
60+
{%- endfor %}
61+
{%- if tool.return is defined and tool.return.description is defined %}
62+
{{- "\n Returns:\n " + tool.return.description }}
63+
{%- endif %}
64+
{{- '"' }}
65+
{{- ', "parameters": ' }}
66+
{%- if tool.parameters.properties | length == 0 %}
67+
{{- "{}" }}
68+
{%- else %}
69+
{{- tool.parameters|tojson }}
70+
{%- endif %}
71+
{{- "}" }}
72+
{%- if not loop.last %}
73+
{{- "\n" }}
74+
{%- endif %}
75+
{%- endfor %}
76+
{%- endif %}
77+
{{- " </tools>" }}
78+
{{- 'Use the following pydantic model json schema for each tool call you will make: {"properties": {"name": {"title": "Name", "type": "string"}, "arguments": {"title": "Arguments", "type": "object"}}, "required": ["name", "arguments"], "title": "FunctionCall", "type": "object"}}
79+
' }}
80+
{{- "For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
81+
" }}
82+
{{- "<tool_call>
83+
" }}
84+
{{- '{"name": <function-name>, "arguments": <args-dict>}
85+
' }}
86+
{{- '</tool_call><|im_end|>' }}
87+
{%- for message in messages %}
88+
{%- if message.role == "user" or message.role == "system" or (message.role == "assistant" and message.tool_calls is not defined) %}
89+
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
90+
{%- elif message.role == "assistant" and message.tool_calls is defined %}
91+
{{- '<|im_start|>' + message.role }}
92+
{%- for tool_call in message.tool_calls %}
93+
{{- '\n<tool_call>\n' }}
94+
{%- if tool_call.function is defined %}
95+
{%- set tool_call = tool_call.function %}
96+
{%- endif %}
97+
{{- '{' }}
98+
{{- '"name": "' }}
99+
{{- tool_call.name }}
100+
{{- '"' }}
101+
{%- if tool_call.arguments is defined %}
102+
{{- ', ' }}
103+
{{- '"arguments": ' }}
104+
{{- tool_call.arguments|tojson }}
105+
{%- endif %}
106+
{{- '}' }}
107+
{{- '\n</tool_call>' }}
108+
{%- endfor %}
109+
{{- '<|im_end|>\n' }}
110+
{%- elif message.role == "tool" %}
111+
{%- if loop.previtem and loop.previtem.role != "tool" %}
112+
{{- '<|im_start|>tool\n' }}
113+
{%- endif %}
114+
{{- '<tool_response>\n' }}
115+
{{- message.content }}
116+
{%- if not loop.last %}
117+
{{- '\n</tool_response>\n' }}
118+
{%- else %}
119+
{{- '\n</tool_response>' }}
120+
{%- endif %}
121+
{%- if not loop.last and loop.nextitem.role != "tool" %}
122+
{{- '<|im_end|>' }}
123+
{%- elif loop.last %}
124+
{{- '<|im_end|>' }}
125+
{%- endif %}
126+
{%- endif %}
127+
{%- endfor %}
128+
{%- if add_generation_prompt %}
129+
{{- '<|im_start|>assistant\n' }}
130+
{%- endif %}
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
{%- if messages[0]["role"] == "system" %}
2+
{%- set system_message = messages[0]["content"] %}
3+
{%- set loop_messages = messages[1:] %}
4+
{%- else %}
5+
{%- set loop_messages = messages %}
6+
{%- endif %}
7+
{%- if not tools is defined %}
8+
{%- set tools = none %}
9+
{%- endif %}
10+
{%- if tools is defined %}
11+
{%- set parallel_tool_prompt = "You are a helpful assistant that can call tools. If you call one or more tools, format them in a single JSON array or objects, where each object is a tool call, not as separate objects outside of an array or multiple arrays. Use the format [{\"name\": tool call name, \"arguments\": tool call arguments}, additional tool calls] if you call more than one tool. If you call tools, do not attempt to interpret them or otherwise provide a response until you receive a tool call result that you can interpret for the user." %}
12+
{%- if system_message is defined %}
13+
{%- set system_message = parallel_tool_prompt + "\n\n" + system_message %}
14+
{%- else %}
15+
{%- set system_message = parallel_tool_prompt %}
16+
{%- endif %}
17+
{%- endif %}
18+
{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %}
19+
20+
{%- for message in loop_messages | rejectattr("role", "equalto", "tool") | rejectattr("role", "equalto", "tool_results") | selectattr("tool_calls", "undefined") %}
21+
{%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %}
22+
{{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }}
23+
{%- endif %}
24+
{%- endfor %}
25+
26+
{{- bos_token }}
27+
{%- for message in loop_messages %}
28+
{%- if message["role"] == "user" %}
29+
{%- if tools is not none and (message == user_messages[-1]) %}
30+
{{- "[AVAILABLE_TOOLS] [" }}
31+
{%- for tool in tools %}
32+
{%- set tool = tool.function %}
33+
{{- '{"type": "function", "function": {' }}
34+
{%- for key, val in tool.items() if key != "return" %}
35+
{%- if val is string %}
36+
{{- '"' + key + '": "' + val + '"' }}
37+
{%- else %}
38+
{{- '"' + key + '": ' + val|tojson }}
39+
{%- endif %}
40+
{%- if not loop.last %}
41+
{{- ", " }}
42+
{%- endif %}
43+
{%- endfor %}
44+
{{- "}}" }}
45+
{%- if not loop.last %}
46+
{{- ", " }}
47+
{%- else %}
48+
{{- "]" }}
49+
{%- endif %}
50+
{%- endfor %}
51+
{{- "[/AVAILABLE_TOOLS]" }}
52+
{%- endif %}
53+
{%- if loop.last and system_message is defined %}
54+
{{- "[INST] " + system_message + "\n\n" + message["content"] + "[/INST]" }}
55+
{%- else %}
56+
{{- "[INST] " + message["content"] + "[/INST]" }}
57+
{%- endif %}
58+
{%- elif message["role"] == "tool_calls" or message.tool_calls is defined %}
59+
{%- if message.tool_calls is defined %}
60+
{%- set tool_calls = message.tool_calls %}
61+
{%- else %}
62+
{%- set tool_calls = message.content %}
63+
{%- endif %}
64+
{{- "[TOOL_CALLS] [" }}
65+
{%- for tool_call in tool_calls %}
66+
{%- set out = tool_call.function|tojson %}
67+
{{- out[:-1] }}
68+
{%- if not tool_call.id is defined or tool_call.id|length < 9 %}
69+
{{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (1)" + tool_call.id) }}
70+
{%- endif %}
71+
{{- ', "id": "' + tool_call.id[-9:] + '"}' }}
72+
{%- if not loop.last %}
73+
{{- ", " }}
74+
{%- else %}
75+
{{- "]" + eos_token }}
76+
{%- endif %}
77+
{%- endfor %}
78+
{%- elif message["role"] == "assistant" %}
79+
{{- " " + message["content"] + eos_token }}
80+
{%- elif message["role"] == "tool_results" or message["role"] == "tool" %}
81+
{%- if message.content is defined and message.content.content is defined %}
82+
{%- set content = message.content.content %}
83+
{%- else %}
84+
{%- set content = message.content %}
85+
{%- endif %}
86+
{{- '[TOOL_RESULTS] {"content": ' + content|string + ", " }}
87+
{%- if not message.tool_call_id is defined or message.tool_call_id|length < 9 %}
88+
{{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (2)" + message.tool_call_id) }}
89+
{%- endif %}
90+
{{- '"call_id": "' + message.tool_call_id[-9:] + '"}[/TOOL_RESULTS]' }}
91+
{%- else %}
92+
{{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}
93+
{%- endif %}
94+
{%- endfor %}

0 commit comments

Comments
 (0)