
Commit 1518ed1

add general purpose function calling handler
1 parent 2d0808b commit 1518ed1

File tree

1 file changed: +343 -0 lines changed

llama_cpp/llama_chat_format.py

Lines changed: 343 additions & 0 deletions
@@ -4127,3 +4127,346 @@ def chatml_function_calling(
        }
        chat_completion["choices"][0]["message"]["function_call"] = single_function_call
    return chat_completion


@register_chat_completion_handler("gguf-function-calling")
def gguf_function_calling(
    llama: llama.Llama,
    messages: List[llama_types.ChatCompletionRequestMessage],
    functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
    function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
    tools: Optional[List[llama_types.ChatCompletionTool]] = None,
    tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
    temperature: float = 0.2,
    top_p: float = 0.95,
    top_k: int = 40,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    stream: bool = False,
    stop: Optional[Union[str, List[str]]] = [],
    response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
    max_tokens: Optional[int] = None,
    presence_penalty: float = 0.0,
    frequency_penalty: float = 0.0,
    repeat_penalty: float = 1.1,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    model: Optional[str] = None,
    logits_processor: Optional[llama.LogitsProcessorList] = None,
    grammar: Optional[llama.LlamaGrammar] = None,
    logprobs: Optional[bool] = None,
    top_logprobs: Optional[int] = None,
    **kwargs: Any,
) -> Union[
    llama_types.CreateChatCompletionResponse,
    Iterator[llama_types.CreateChatCompletionStreamResponse],
]:
    function_calling_template = None
    if hasattr(llama, 'model_path'):
        from llama_cpp.llama import Llama
        metadata = Llama.get_metadata(llama.model_path)
        if metadata and "tokenizer.chat_template" in metadata:
            function_calling_template = metadata["tokenizer.chat_template"]
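    # NOTE: a template found in the GGUF metadata above is currently not applied;
    # the hard-coded ChatML-style function-calling template below always takes effect.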
    function_calling_template = (
        "{% for message in messages %}"
        "<|im_start|>{{ message.role }}\n"
        # System message
        "{% if message.role == 'system' %}"
        "{{ message.content }}"
        "{% if tool_calls %}"
        "\n\nYou have access to the following functions:\n"
        "{% for tool in tools %}"
        '\n{% if tool.function.get("description") %}/* {{ tool.function.description | trim }} */{% endif %}'
        "\nfunctions.{{ tool.function.name }}:\n"
        "{{ tool.function.parameters | tojson }}"
        "\n{% endfor %}"
        "\nYou must respond to user messages with either a single message or with one or more function calls."
        "\n\nTo respond with a message use the following format:"
        "\n\nmessage:"
        "\n<message>"
        "\n\nTo respond with one or more function calls use the following format:"
        "\n\n<function_calls>"
        "\nfunctions.<function_name>:"
        '\n{ "arg1": "value1", "arg2": "value2" }'
        "\nfunctions.<function_name>:"
        '\n{ "arg1": "value1", "arg2": "value2" }'
        "\n</function_calls>"
        "{% endif %}"
        "<|im_end|>\n"
        "{% endif %}"
        # User message
        "{% if message.role == 'user' %}"
        "{{ message.content }}"
        "<|im_end|>\n"
        "{% endif %}"
        # Assistant message
        "{% if message.role == 'assistant' %}"
        ## Regular message
        "{% if message.content and message.content | length > 0 %}"
        "{% if tool_calls %}"
        "message:\n"
        "{% endif %}"
        "{{ message.content }}"
        "<|im_end|>\n"
        "{% endif %}"
        ## Function calls
        "{% if 'tool_calls' in message %}"
        "{% for tool_call in message.tool_calls %}"
        "functions.{{ tool_call.function.name }}:\n"
        "{{ tool_call.function.arguments }}"
        "{% endfor %}"
        "<|im_end|>\n"
        "{% endif %}"
        "{% endif %}"
        "{% endfor %}"
        "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
    )
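    # Per the template above, the assistant is instructed to reply either with
    #   message:
    #   <free-form text>
    # or with one or more function calls of the form
    #   <function_calls>
    #   functions.<function_name>:
    #   { "arg1": "value1", ... }
    #   </function_calls>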
    template_renderer = ImmutableSandboxedEnvironment(
        autoescape=jinja2.select_autoescape(["html", "xml"]),
        undefined=jinja2.StrictUndefined,
    ).from_string(function_calling_template)
    # Convert legacy functions to tools
    if functions is not None:
        tools = [{"type": "function", "function": function} for function in functions]

    # Convert legacy function_call to tool_choice
    if function_call is not None:
        if isinstance(function_call, str) and (function_call in ("none", "auto")):
            tool_choice = function_call
        if isinstance(function_call, dict) and "name" in function_call:
            tool_choice = {"type": "function", "function": {"name": function_call["name"]}}

    # Collect the llama.create_completion keyword arguments so we don't have to repeat these with
    # each completion call
    stop = (
        [stop, "<|im_end|>"]
        if isinstance(stop, str)
        else [*stop, "<|im_end|>"]
        if stop
        else ["<|im_end|>"]
    )
    grammar = (  # It is assumed the grammar applies to messages only, not tool calls
        grammar
        if grammar is not None
        else (
            _grammar_for_response_format(response_format)
            if response_format is not None and response_format["type"] == "json_object"
            else None
        )
    )
    completion_kwargs = {
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "min_p": min_p,
        "typical_p": typical_p,
        "stream": stream,
        "stop": stop,
        "max_tokens": max_tokens,
        "presence_penalty": presence_penalty,
        "frequency_penalty": frequency_penalty,
        "repeat_penalty": repeat_penalty,
        "tfs_z": tfs_z,
        "mirostat_mode": mirostat_mode,
        "mirostat_tau": mirostat_tau,
        "mirostat_eta": mirostat_eta,
        "model": model,
        "logits_processor": logits_processor,
        "grammar": grammar,
    }
    # Case 1: No tool use
    if (
        tool_choice is None
        or (isinstance(tool_choice, str) and tool_choice == "none")
        or tools is None
        or len(tools) == 0
    ):
        prompt = template_renderer.render(
            messages=messages, tools=[], tool_calls=None, add_generation_prompt=True
        )
        return _convert_completion_to_chat(
            llama.create_completion(
                prompt=prompt,
                **completion_kwargs,  # type: ignore[arg-type]
                logprobs=top_logprobs if logprobs else None,
            ),
            stream=stream,
        )
    # Ensure there is a system prompt to attach the tool metadata to
    if not any(message["role"] == "system" for message in messages):
        messages = [*messages, {"role": "system", "content": ""}]

    # Case 2: Automatic or fixed tool choice
    # Case 2 step 1: Determine whether to respond with a message or a tool call
    assert (isinstance(tool_choice, str) and tool_choice == "auto") or isinstance(tool_choice, dict)
    if isinstance(tool_choice, dict):
        tools = [t for t in tools if t["function"]["name"] == tool_choice["function"]["name"]]
        assert tools
    function_names = " | ".join([f'''"functions.{t['function']['name']}:"''' for t in tools])
    prompt = template_renderer.render(
        messages=messages, tools=tools, tool_calls=True, add_generation_prompt=True
    )
    initial_gbnf_tool_grammar = (
        (
            'root ::= "<function_calls>" "\\n" functions | "message:"\n'
            f"functions ::= {function_names}\n"
        )
        if tool_choice == "auto"
        else f'root ::= "<function_calls>" "\\n" functions\nfunctions ::= {function_names}\n'
    )
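    # For illustration: with a single hypothetical tool named "get_weather" and
    # tool_choice="auto", the grammar above expands to
    #   root ::= "<function_calls>" "\n" functions | "message:"
    #   functions ::= "functions.get_weather:"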
    completion = cast(
        llama_types.CreateCompletionResponse,
        llama.create_completion(
            prompt=prompt,
            **{  # type: ignore[arg-type]
                **completion_kwargs,
                "temperature": 0,
                "stream": False,
                "stop": [":"],
                "max_tokens": None,
                "grammar": llama_grammar.LlamaGrammar.from_string(
                    initial_gbnf_tool_grammar, verbose=llama.verbose
                ),
            },
        ),
    )
    text = completion["choices"][0]["text"]
    tool_name = None if text.startswith("message") else text.split("\n")[-1][len("functions.") :]
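    # The grammar-constrained completion either begins with "message" (plain reply)
    # or its last line is "functions.<name>", naming the first tool to call.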
    # Case 2 step 2A: Respond with a message
    if tool_name is None:
        prompt = template_renderer.render(
            messages=messages, tools=[], tool_calls=None, add_generation_prompt=True
        )
        return _convert_completion_to_chat(
            llama.create_completion(
                prompt=prompt,
                **completion_kwargs,  # type: ignore[arg-type]
                logprobs=top_logprobs if logprobs else None,
            ),
            stream=stream,
        )
    # Case 2 step 2B: One or more function calls
    follow_up_gbnf_tool_grammar = (
        'root ::= functions | "</function_calls>" | "<|im_end|>"\n'
        f"functions ::= {function_names}\n"
    )
    prompt += "<function_calls>\n"
    if stream:
        return _stream_tool_calls(
            llama, prompt, tools, tool_name, completion_kwargs, follow_up_gbnf_tool_grammar
        )
    tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
    completions: List[llama_types.CreateCompletionResponse] = []
    completions_tool_name: List[str] = []
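    # Chain tool calls: generate grammar-constrained arguments for the selected tool,
    # then ask the model whether to call another tool or stop; the loop is bounded to
    # guard against runaway generations.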
    while tool is not None and len(completions) <= 16:
        # Generate the parameter values for the selected tool
        prompt += f"functions.{tool_name}:\n"
        try:
            grammar = llama_grammar.LlamaGrammar.from_json_schema(
                json.dumps(tool["function"]["parameters"]), verbose=llama.verbose
            )
        except Exception as e:
            warnings.warn(
                f"Failed to parse function body as JSON schema, falling back to default grammar\n\n{e}",
                category=RuntimeWarning,
                stacklevel=2,
            )
            grammar = llama_grammar.LlamaGrammar.from_string(
                llama_grammar.JSON_GBNF, verbose=llama.verbose
            )
        completion_or_chunks = llama.create_completion(
            prompt=prompt,
            **{  # type: ignore[arg-type]
                **completion_kwargs,
                "max_tokens": None,
                "grammar": grammar,
            },
        )
        completion = cast(llama_types.CreateCompletionResponse, completion_or_chunks)
        completions.append(completion)
        completions_tool_name.append(tool_name)
        prompt += completion["choices"][0]["text"]
        prompt += "\n"
        # Determine whether to call another tool or stop
        response = cast(
            llama_types.CreateCompletionResponse,
            llama.create_completion(
                prompt=prompt,
                **{  # type: ignore[arg-type]
                    **completion_kwargs,
                    "temperature": 0,
                    "stream": False,
                    "stop": [*completion_kwargs["stop"], ":", "</function_calls>"],  # type: ignore[misc]
                    "max_tokens": None,
                    "grammar": llama_grammar.LlamaGrammar.from_string(
                        follow_up_gbnf_tool_grammar, verbose=llama.verbose
                    ),
                },
            ),
        )
        tool_name = response["choices"][0]["text"][len("functions.") :]
        tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
    # Merge the completions into a single chat completion
    chat_completion: llama_types.CreateChatCompletionResponse = {
        "id": "chat" + completion["id"],
        "object": "chat.completion",
        "created": completion["created"],
        "model": completion["model"],
        "choices": [
            {
                "finish_reason": "tool_calls",
                "index": 0,
                "logprobs": _convert_text_completion_logprobs_to_chat(
                    completion["choices"][0]["logprobs"]
                ),
                "message": {
                    "role": "assistant",
                    "content": None,
                    "tool_calls": [
                        {
                            "id": "call_" + f"_{i}_" + tool_name + "_" + completion["id"],
                            "type": "function",
                            "function": {
                                "name": tool_name,
                                "arguments": completion["choices"][0]["text"],
                            },
                        }
                        for i, (tool_name, completion) in enumerate(
                            zip(completions_tool_name, completions)
                        )
                    ],
                },
            }
        ],
        "usage": {
            "completion_tokens": sum(
                (completion["usage"]["completion_tokens"] if "usage" in completion else 0)
                for completion in completions
            ),
            "prompt_tokens": sum(
                completion["usage"]["prompt_tokens"] if "usage" in completion else 0
                for completion in completions
            ),
            "total_tokens": sum(
                completion["usage"]["total_tokens"] if "usage" in completion else 0
                for completion in completions
            ),
        },
    }
    if len(completions) == 1:
        single_function_call: llama_types.ChatCompletionResponseFunctionCall = {
            "name": tool_name,
            "arguments": completions[0]["choices"][0]["text"],
        }
        chat_completion["choices"][0]["message"]["function_call"] = single_function_call
    return chat_completion
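
Usage sketch (not part of the commit): assuming the handler registered above under "gguf-function-calling" is resolved through llama-cpp-python's usual chat_format registry, a call might look like the following; the model path and the get_weather tool are made up for illustration.

from llama_cpp import Llama

llm = Llama(
    model_path="./models/example-chatml.Q4_K_M.gguf",  # hypothetical path
    chat_format="gguf-function-calling",
)

response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "What is the weather in Paris?"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_weather",  # hypothetical tool
                "description": "Get the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
    tool_choice="auto",
)
print(response["choices"][0]["message"].get("tool_calls"))

With tool_choice="auto" the handler first lets the model pick between a plain message (Case 1 / step 2A) and one or more tool calls (step 2B); passing a specific {"type": "function", "function": {"name": ...}} forces the named tool.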
