
Commit 22a16bd

add general purpose function calling handler
1 parent 4cf4b15 commit 22a16bd

File tree

1 file changed: +343 -0 lines changed


llama_cpp/llama_chat_format.py

Lines changed: 343 additions & 0 deletions
@@ -4118,3 +4118,346 @@ def chatml_function_calling(
        }
        chat_completion["choices"][0]["message"]["function_call"] = single_function_call
    return chat_completion


@register_chat_completion_handler("gguf-function-calling")
def gguf_function_calling(
    llama: llama.Llama,
    messages: List[llama_types.ChatCompletionRequestMessage],
    functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
    function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
    tools: Optional[List[llama_types.ChatCompletionTool]] = None,
    tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
    temperature: float = 0.2,
    top_p: float = 0.95,
    top_k: int = 40,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    stream: bool = False,
    stop: Optional[Union[str, List[str]]] = [],
    response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
    max_tokens: Optional[int] = None,
    presence_penalty: float = 0.0,
    frequency_penalty: float = 0.0,
    repeat_penalty: float = 1.1,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    model: Optional[str] = None,
    logits_processor: Optional[llama.LogitsProcessorList] = None,
    grammar: Optional[llama.LlamaGrammar] = None,
    logprobs: Optional[bool] = None,
    top_logprobs: Optional[int] = None,
    **kwargs: Any,
) -> Union[
    llama_types.CreateChatCompletionResponse,
    Iterator[llama_types.CreateChatCompletionStreamResponse],
]:
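    """Generic chat handler with grammar-constrained function calling.

    Renders the conversation with a ChatML-style template (preferring a chat
    template found in the model's GGUF metadata when one is available), then
    uses GBNF grammars to decide between a plain message and one or more
    functions.<name>: calls, constraining each call's arguments to the tool's
    JSON schema.
    """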
    # Prefer a chat template embedded in the model's GGUF metadata, if present
    function_calling_template = None
    if hasattr(llama, "model_path"):
        from llama_cpp.llama import Llama

        metadata = Llama.get_metadata(llama.model_path)
        if metadata and "tokenizer.chat.template" in metadata:
            function_calling_template = metadata["tokenizer.chat.template"]

    # Otherwise fall back to the default ChatML-style function calling template
    if function_calling_template is None:
        function_calling_template = (
            "{% for message in messages %}"
            "<|im_start|>{{ message.role }}\n"
            # System message
            "{% if message.role == 'system' %}"
            "{{ message.content }}"
            "{% if tool_calls %}"
            "\n\nYou have access to the following functions:\n"
            "{% for tool in tools %}"
            '\n{% if tool.function.get("description") %}/* {{ tool.function.description | trim }} */{% endif %}'
            "\nfunctions.{{ tool.function.name }}:\n"
            "{{ tool.function.parameters | tojson }}"
            "\n{% endfor %}"
            "\nYou must respond to user messages with either a single message or with one or more function calls."
            "\n\nTo respond with a message use the following format:"
            "\n\nmessage:"
            "\n<message>"
            "\n\nTo respond with one or more function calls use the following format:"
            "\n\n<function_calls>"
            "\nfunctions.<function_name>:"
            '\n{ "arg1": "value1", "arg2": "value2" }'
            "\nfunctions.<function_name>:"
            '\n{ "arg1": "value1", "arg2": "value2" }'
            "\n</function_calls>"
            "{% endif %}"
            "<|im_end|>\n"
            "{% endif %}"
            # User message
            "{% if message.role == 'user' %}"
            "{{ message.content }}"
            "<|im_end|>\n"
            "{% endif %}"
            # Assistant message
            "{% if message.role == 'assistant' %}"
            ## Regular message
            "{% if message.content and message.content | length > 0 %}"
            "{% if tool_calls %}"
            "message:\n"
            "{% endif %}"
            "{{ message.content }}"
            "<|im_end|>\n"
            "{% endif %}"
            ## Function calls
            "{% if 'tool_calls' in message %}"
            "{% for tool_call in message.tool_calls %}"
            "functions.{{ tool_call.function.name }}:\n"
            "{{ tool_call.function.arguments }}"
            "{% endfor %}"
            "<|im_end|>\n"
            "{% endif %}"
            "{% endif %}"
            "{% endfor %}"
            "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
        )
    template_renderer = ImmutableSandboxedEnvironment(
        autoescape=jinja2.select_autoescape(["html", "xml"]),
        undefined=jinja2.StrictUndefined,
    ).from_string(function_calling_template)

    # Convert legacy functions to tools
    if functions is not None:
        tools = [{"type": "function", "function": function} for function in functions]

    # Convert legacy function_call to tool_choice
    if function_call is not None:
        if isinstance(function_call, str) and (function_call in ("none", "auto")):
            tool_choice = function_call
        if isinstance(function_call, dict) and "name" in function_call:
            tool_choice = {"type": "function", "function": {"name": function_call["name"]}}

    # Collect the llama.create_completion keyword arguments so we don't have to repeat these with
    # each completion call
    stop = (
        [stop, "<|im_end|>"]
        if isinstance(stop, str)
        else [*stop, "<|im_end|>"]
        if stop
        else ["<|im_end|>"]
    )
    grammar = (  # It is assumed the grammar applies to messages only, not tool calls
        grammar
        if grammar is not None
        else (
            _grammar_for_response_format(response_format)
            if response_format is not None and response_format["type"] == "json_object"
            else None
        )
    )
    completion_kwargs = {
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "min_p": min_p,
        "typical_p": typical_p,
        "stream": stream,
        "stop": stop,
        "max_tokens": max_tokens,
        "presence_penalty": presence_penalty,
        "frequency_penalty": frequency_penalty,
        "repeat_penalty": repeat_penalty,
        "tfs_z": tfs_z,
        "mirostat_mode": mirostat_mode,
        "mirostat_tau": mirostat_tau,
        "mirostat_eta": mirostat_eta,
        "model": model,
        "logits_processor": logits_processor,
        "grammar": grammar,
    }

    # Case 1: No tool use
    if (
        tool_choice is None
        or (isinstance(tool_choice, str) and tool_choice == "none")
        or tools is None
        or len(tools) == 0
    ):
        prompt = template_renderer.render(
            messages=messages, tools=[], tool_calls=None, add_generation_prompt=True
        )
        return _convert_completion_to_chat(
            llama.create_completion(
                prompt=prompt,
                **completion_kwargs,  # type: ignore[arg-type]
                logprobs=top_logprobs if logprobs else None,
            ),
            stream=stream,
        )

    # Ensure there is a system prompt to attach the tool metadata to
    if not any(message["role"] == "system" for message in messages):
        messages = [*messages, {"role": "system", "content": ""}]

    # Case 2: Automatic or fixed tool choice
    # Case 2 step 1: Determine whether to respond with a message or a tool call
    assert (isinstance(tool_choice, str) and tool_choice == "auto") or isinstance(tool_choice, dict)
    if isinstance(tool_choice, dict):
        tools = [t for t in tools if t["function"]["name"] == tool_choice["function"]["name"]]
    assert tools
    function_names = " | ".join([f'''"functions.{t['function']['name']}:"''' for t in tools])
    prompt = template_renderer.render(
        messages=messages, tools=tools, tool_calls=True, add_generation_prompt=True
    )
    initial_gbnf_tool_grammar = (
        (
            'root ::= "<function_calls>" "\\n" functions | "message:"\n'
            f"functions ::= {function_names}\n"
        )
        if tool_choice == "auto"
        else f'root ::= "<function_calls>" "\\n" functions\nfunctions ::= {function_names}\n'
    )
    completion = cast(
        llama_types.CreateCompletionResponse,
        llama.create_completion(
            prompt=prompt,
            **{  # type: ignore[arg-type]
                **completion_kwargs,
                "temperature": 0,
                "stream": False,
                "stop": [":"],
                "max_tokens": None,
                "grammar": llama_grammar.LlamaGrammar.from_string(
                    initial_gbnf_tool_grammar, verbose=llama.verbose
                ),
            },
        ),
    )
    text = completion["choices"][0]["text"]
    tool_name = None if text.startswith("message") else text.split("\n")[-1][len("functions.") :]

    # Case 2 step 2A: Respond with a message
    if tool_name is None:
        prompt = template_renderer.render(
            messages=messages, tools=[], tool_calls=None, add_generation_prompt=True
        )
        return _convert_completion_to_chat(
            llama.create_completion(
                prompt=prompt,
                **completion_kwargs,  # type: ignore[arg-type]
                logprobs=top_logprobs if logprobs else None,
            ),
            stream=stream,
        )

    # Case 2 step 2B: One or more function calls
    follow_up_gbnf_tool_grammar = (
        'root ::= functions | "</function_calls>" | "<|im_end|>"\n'
        f"functions ::= {function_names}\n"
    )
    prompt += "<function_calls>\n"
    if stream:
        return _stream_tool_calls(
            llama, prompt, tools, tool_name, completion_kwargs, follow_up_gbnf_tool_grammar
        )
    tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
    completions: List[llama_types.CreateCompletionResponse] = []
    completions_tool_name: List[str] = []
    while tool is not None and len(completions) <= 16:
        # Generate the parameter values for the selected tool
        prompt += f"functions.{tool_name}:\n"
        try:
            grammar = llama_grammar.LlamaGrammar.from_json_schema(
                json.dumps(tool["function"]["parameters"]), verbose=llama.verbose
            )
        except Exception as e:
            warnings.warn(
                f"Failed to parse function body as JSON schema, falling back to default grammar\n\n{e}",
                category=RuntimeWarning,
                stacklevel=2,
            )
            grammar = llama_grammar.LlamaGrammar.from_string(
                llama_grammar.JSON_GBNF, verbose=llama.verbose
            )
        completion_or_chunks = llama.create_completion(
            prompt=prompt,
            **{  # type: ignore[arg-type]
                **completion_kwargs,
                "max_tokens": None,
                "grammar": grammar,
            },
        )
        completion = cast(llama_types.CreateCompletionResponse, completion_or_chunks)
        completions.append(completion)
        completions_tool_name.append(tool_name)
        prompt += completion["choices"][0]["text"]
        prompt += "\n"
        # Determine whether to call another tool or stop
        response = cast(
            llama_types.CreateCompletionResponse,
            llama.create_completion(
                prompt=prompt,
                **{  # type: ignore[arg-type]
                    **completion_kwargs,
                    "temperature": 0,
                    "stream": False,
                    "stop": [*completion_kwargs["stop"], ":", "</function_calls>"],  # type: ignore[misc]
                    "max_tokens": None,
                    "grammar": llama_grammar.LlamaGrammar.from_string(
                        follow_up_gbnf_tool_grammar, verbose=llama.verbose
                    ),
                },
            ),
        )
        tool_name = response["choices"][0]["text"][len("functions.") :]
        tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
    # Merge the completions into a single chat completion
    chat_completion: llama_types.CreateChatCompletionResponse = {
        "id": "chat" + completion["id"],
        "object": "chat.completion",
        "created": completion["created"],
        "model": completion["model"],
        "choices": [
            {
                "finish_reason": "tool_calls",
                "index": 0,
                "logprobs": _convert_text_completion_logprobs_to_chat(
                    completion["choices"][0]["logprobs"]
                ),
                "message": {
                    "role": "assistant",
                    "content": None,
                    "tool_calls": [
                        {
                            "id": "call_" + f"_{i}_" + tool_name + "_" + completion["id"],
                            "type": "function",
                            "function": {
                                "name": tool_name,
                                "arguments": completion["choices"][0]["text"],
                            },
                        }
                        for i, (tool_name, completion) in enumerate(
                            zip(completions_tool_name, completions)
                        )
                    ],
                },
            }
        ],
        "usage": {
            "completion_tokens": sum(
                (completion["usage"]["completion_tokens"] if "usage" in completion else 0)
                for completion in completions
            ),
            "prompt_tokens": sum(
                completion["usage"]["prompt_tokens"] if "usage" in completion else 0
                for completion in completions
            ),
            "total_tokens": sum(
                completion["usage"]["total_tokens"] if "usage" in completion else 0
                for completion in completions
            ),
        },
    }
    if len(completions) == 1:
        single_function_call: llama_types.ChatCompletionResponseFunctionCall = {
            "name": tool_name,
            "arguments": completions[0]["choices"][0]["text"],
        }
        chat_completion["choices"][0]["message"]["function_call"] = single_function_call
    return chat_completion
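
For reference, here is a minimal usage sketch of the handler added above (not part of this commit). It assumes the registered name "gguf-function-calling" can be passed as chat_format when constructing Llama, as with the other registered handlers; the model path and the get_weather tool are hypothetical placeholders.

# Minimal usage sketch (not part of this commit). Assumes the registered name
# "gguf-function-calling" resolves as a chat_format like any other handler;
# the model path and the get_weather tool are hypothetical placeholders.
from llama_cpp import Llama

llm = Llama(
    model_path="./models/example-chatml-model.Q4_K_M.gguf",  # hypothetical path
    chat_format="gguf-function-calling",
    n_ctx=4096,
)

response = llm.create_chat_completion(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What's the weather in Paris?"},
    ],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_weather",  # hypothetical tool
                "description": "Get the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
    tool_choice="auto",
)

# If the model chose to call the tool, the arguments come back as a JSON string.
print(response["choices"][0]["message"].get("tool_calls"))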
