Commit f9471b6

add general purpose function calling handler
1 parent 646beb7 commit f9471b6

File tree

1 file changed: llama_cpp/llama_chat_format.py

Lines changed: 343 additions & 0 deletions
@@ -3980,3 +3980,346 @@ def chatml_function_calling(
        }
        chat_completion["choices"][0]["message"]["function_call"] = single_function_call
    return chat_completion


@register_chat_completion_handler("gguf-function-calling")
def gguf_function_calling(
    llama: llama.Llama,
    messages: List[llama_types.ChatCompletionRequestMessage],
    functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
    function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
    tools: Optional[List[llama_types.ChatCompletionTool]] = None,
    tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
    temperature: float = 0.2,
    top_p: float = 0.95,
    top_k: int = 40,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    stream: bool = False,
    stop: Optional[Union[str, List[str]]] = [],
    response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
    max_tokens: Optional[int] = None,
    presence_penalty: float = 0.0,
    frequency_penalty: float = 0.0,
    repeat_penalty: float = 1.1,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    model: Optional[str] = None,
    logits_processor: Optional[llama.LogitsProcessorList] = None,
    grammar: Optional[llama.LlamaGrammar] = None,
    logprobs: Optional[bool] = None,
    top_logprobs: Optional[int] = None,
    **kwargs: Any,
) -> Union[
    llama_types.CreateChatCompletionResponse,
    Iterator[llama_types.CreateChatCompletionStreamResponse],
]:
    function_calling_template = None
    if hasattr(llama, "model_path"):
        from llama_cpp.llama import Llama

        metadata = Llama.get_metadata(llama.model_path)
        if metadata and "tokenizer.chat.template" in metadata:
            function_calling_template = metadata["tokenizer.chat.template"]

    function_calling_template = (
        "{% for message in messages %}"
        "<|im_start|>{{ message.role }}\n"
        # System message
        "{% if message.role == 'system' %}"
        "{{ message.content }}"
        "{% if tool_calls %}"
        "\n\nYou have access to the following functions:\n"
        "{% for tool in tools %}"
        '\n{% if tool.function.get("description") %}/* {{ tool.function.description | trim }} */{% endif %}'
        "\nfunctions.{{ tool.function.name }}:\n"
        "{{ tool.function.parameters | tojson }}"
        "\n{% endfor %}"
        "\nYou must respond to user messages with either a single message or with one or more function calls."
        "\n\nTo respond with a message use the following format:"
        "\n\nmessage:"
        "\n<message>"
        "\n\nTo respond with one or more function calls use the following format:"
        "\n\n<function_calls>"
        "\nfunctions.<function_name>:"
        '\n{ "arg1": "value1", "arg2": "value2" }'
        "\nfunctions.<function_name>:"
        '\n{ "arg1": "value1", "arg2": "value2" }'
        "\n</function_calls>"
        "{% endif %}"
        "<|im_end|>\n"
        "{% endif %}"
        # User message
        "{% if message.role == 'user' %}"
        "{{ message.content }}"
        "<|im_end|>\n"
        "{% endif %}"
        # Assistant message
        "{% if message.role == 'assistant' %}"
        ## Regular message
        "{% if message.content and message.content | length > 0 %}"
        "{% if tool_calls %}"
        "message:\n"
        "{% endif %}"
        "{{ message.content }}"
        "<|im_end|>\n"
        "{% endif %}"
        ## Function calls
        "{% if 'tool_calls' in message %}"
        "{% for tool_call in message.tool_calls %}"
        "functions.{{ tool_call.function.name }}:\n"
        "{{ tool_call.function.arguments }}"
        "{% endfor %}"
        "<|im_end|>\n"
        "{% endif %}"
        "{% endif %}"
        "{% endfor %}"
        "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
    )
    template_renderer = ImmutableSandboxedEnvironment(
        autoescape=jinja2.select_autoescape(["html", "xml"]),
        undefined=jinja2.StrictUndefined,
    ).from_string(function_calling_template)

    # Convert legacy functions to tools
    if functions is not None:
        tools = [{"type": "function", "function": function} for function in functions]

    # Convert legacy function_call to tool_choice
    if function_call is not None:
        if isinstance(function_call, str) and (function_call in ("none", "auto")):
            tool_choice = function_call
        if isinstance(function_call, dict) and "name" in function_call:
            tool_choice = {"type": "function", "function": {"name": function_call["name"]}}

    # Collect the llama.create_completion keyword arguments so we don't have to repeat these with
    # each completion call
    stop = (
        [stop, "<|im_end|>"]
        if isinstance(stop, str)
        else [*stop, "<|im_end|>"]
        if stop
        else ["<|im_end|>"]
    )
    grammar = (  # It is assumed the grammar applies to messages only, not tool calls
        grammar
        if grammar is not None
        else (
            _grammar_for_response_format(response_format)
            if response_format is not None and response_format["type"] == "json_object"
            else None
        )
    )
    completion_kwargs = {
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "min_p": min_p,
        "typical_p": typical_p,
        "stream": stream,
        "stop": stop,
        "max_tokens": max_tokens,
        "presence_penalty": presence_penalty,
        "frequency_penalty": frequency_penalty,
        "repeat_penalty": repeat_penalty,
        "tfs_z": tfs_z,
        "mirostat_mode": mirostat_mode,
        "mirostat_tau": mirostat_tau,
        "mirostat_eta": mirostat_eta,
        "model": model,
        "logits_processor": logits_processor,
        "grammar": grammar,
    }

    # Case 1: No tool use
    if (
        tool_choice is None
        or (isinstance(tool_choice, str) and tool_choice == "none")
        or tools is None
        or len(tools) == 0
    ):
        prompt = template_renderer.render(
            messages=messages, tools=[], tool_calls=None, add_generation_prompt=True
        )
        return _convert_completion_to_chat(
            llama.create_completion(
                prompt=prompt,
                **completion_kwargs,  # type: ignore[arg-type]
                logprobs=top_logprobs if logprobs else None,
            ),
            stream=stream,
        )

    # Ensure there is a system prompt to attach the tool metadata to
    if not any(message["role"] == "system" for message in messages):
        messages = [*messages, {"role": "system", "content": ""}]

    # Case 2: Automatic or fixed tool choice
    # Case 2 step 1: Determine whether to respond with a message or a tool call
    assert (isinstance(tool_choice, str) and tool_choice == "auto") or isinstance(tool_choice, dict)
    if isinstance(tool_choice, dict):
        tools = [t for t in tools if t["function"]["name"] == tool_choice["function"]["name"]]
    assert tools
    function_names = " | ".join([f'''"functions.{t['function']['name']}:"''' for t in tools])
    prompt = template_renderer.render(
        messages=messages, tools=tools, tool_calls=True, add_generation_prompt=True
    )
    initial_gbnf_tool_grammar = (
        (
            'root ::= "<function_calls>" "\\n" functions | "message:"\n'
            f"functions ::= {function_names}\n"
        )
        if tool_choice == "auto"
        else f'root ::= "<function_calls>" "\\n" functions\nfunctions ::= {function_names}\n'
    )
    completion = cast(
        llama_types.CreateCompletionResponse,
        llama.create_completion(
            prompt=prompt,
            **{  # type: ignore[arg-type]
                **completion_kwargs,
                "temperature": 0,
                "stream": False,
                "stop": [":"],
                "max_tokens": None,
                "grammar": llama_grammar.LlamaGrammar.from_string(
                    initial_gbnf_tool_grammar, verbose=llama.verbose
                ),
            },
        ),
    )
    text = completion["choices"][0]["text"]
    tool_name = None if text.startswith("message") else text.split("\n")[-1][len("functions.") :]

    # Case 2 step 2A: Respond with a message
    if tool_name is None:
        prompt = template_renderer.render(
            messages=messages, tools=[], tool_calls=None, add_generation_prompt=True
        )
        return _convert_completion_to_chat(
            llama.create_completion(
                prompt=prompt,
                **completion_kwargs,  # type: ignore[arg-type]
                logprobs=top_logprobs if logprobs else None,
            ),
            stream=stream,
        )

    # Case 2 step 2B: One or more function calls
    follow_up_gbnf_tool_grammar = (
        'root ::= functions | "</function_calls>" | "<|im_end|>"\n'
        f"functions ::= {function_names}\n"
    )
    prompt += "<function_calls>\n"
    if stream:
        return _stream_tool_calls(
            llama, prompt, tools, tool_name, completion_kwargs, follow_up_gbnf_tool_grammar
        )
    tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
    completions: List[llama_types.CreateCompletionResponse] = []
    completions_tool_name: List[str] = []
    while tool is not None and len(completions) <= 16:
        # Generate the parameter values for the selected tool
        prompt += f"functions.{tool_name}:\n"
        try:
            grammar = llama_grammar.LlamaGrammar.from_json_schema(
                json.dumps(tool["function"]["parameters"]), verbose=llama.verbose
            )
        except Exception as e:
            warnings.warn(
                f"Failed to parse function body as JSON schema, falling back to default grammar\n\n{e}",
                category=RuntimeWarning,
                stacklevel=2,
            )
            grammar = llama_grammar.LlamaGrammar.from_string(
                llama_grammar.JSON_GBNF, verbose=llama.verbose
            )
        completion_or_chunks = llama.create_completion(
            prompt=prompt,
            **{  # type: ignore[arg-type]
                **completion_kwargs,
                "max_tokens": None,
                "grammar": grammar,
            },
        )
        completion = cast(llama_types.CreateCompletionResponse, completion_or_chunks)
        completions.append(completion)
        completions_tool_name.append(tool_name)
        prompt += completion["choices"][0]["text"]
        prompt += "\n"
        # Determine whether to call another tool or stop
        response = cast(
            llama_types.CreateCompletionResponse,
            llama.create_completion(
                prompt=prompt,
                **{  # type: ignore[arg-type]
                    **completion_kwargs,
                    "temperature": 0,
                    "stream": False,
                    "stop": [*completion_kwargs["stop"], ":", "</function_calls>"],  # type: ignore[misc]
                    "max_tokens": None,
                    "grammar": llama_grammar.LlamaGrammar.from_string(
                        follow_up_gbnf_tool_grammar, verbose=llama.verbose
                    ),
                },
            ),
        )
        tool_name = response["choices"][0]["text"][len("functions.") :]
        tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
    # Merge the completions into a single chat completion
    chat_completion: llama_types.CreateChatCompletionResponse = {
        "id": "chat" + completion["id"],
        "object": "chat.completion",
        "created": completion["created"],
        "model": completion["model"],
        "choices": [
            {
                "finish_reason": "tool_calls",
                "index": 0,
                "logprobs": _convert_text_completion_logprobs_to_chat(
                    completion["choices"][0]["logprobs"]
                ),
                "message": {
                    "role": "assistant",
                    "content": None,
                    "tool_calls": [
                        {
                            "id": "call_" + f"_{i}_" + tool_name + "_" + completion["id"],
                            "type": "function",
                            "function": {
                                "name": tool_name,
                                "arguments": completion["choices"][0]["text"],
                            },
                        }
                        for i, (tool_name, completion) in enumerate(
                            zip(completions_tool_name, completions)
                        )
                    ],
                },
            }
        ],
        "usage": {
            "completion_tokens": sum(
                (completion["usage"]["completion_tokens"] if "usage" in completion else 0)
                for completion in completions
            ),
            "prompt_tokens": sum(
                completion["usage"]["prompt_tokens"] if "usage" in completion else 0
                for completion in completions
            ),
            "total_tokens": sum(
                completion["usage"]["total_tokens"] if "usage" in completion else 0
                for completion in completions
            ),
        },
    }
    if len(completions) == 1:
        single_function_call: llama_types.ChatCompletionResponseFunctionCall = {
            "name": tool_name,
            "arguments": completions[0]["choices"][0]["text"],
        }
        chat_completion["choices"][0]["message"]["function_call"] = single_function_call
    return chat_completion
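
A minimal usage sketch, assuming the handler is selected by its registered name via Llama's chat_format argument; the model path and the get_weather tool schema below are illustrative placeholders, not part of the commit:

from llama_cpp import Llama

# Assumption: passing the registered handler name as chat_format routes
# create_chat_completion through the gguf_function_calling handler added above.
llm = Llama(model_path="./model.gguf", chat_format="gguf-function-calling")

# Hypothetical tool definition used only for illustration.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "What is the weather in Paris?"}],
    tools=tools,
    tool_choice="auto",
)
print(response["choices"][0]["message"])

With tool_choice="auto", the handler first runs a temperature-0, grammar-constrained step to choose between "message:" and "<function_calls>", then generates each call's arguments under a grammar built from that tool's JSON schema, so the returned tool_calls arguments follow the declared parameter shape.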
