
Commit 3978f56
feat: align function id with tool call response
1 parent: f91434e

6 files changed, +63 -28 lines changed

router/src/chat.rs
Lines changed: 1 addition & 0 deletions

@@ -49,6 +49,7 @@ pub(crate) fn parse_output(generated_text: &str) -> Result<ChatChoice, InferError>
             id: "0".to_string(),
             r#type: "function".to_string(),
             function: FunctionDefinition {
+                id: None,
                 description: None,
                 name: name.to_string(),
                 arguments: serde_json::to_value(call.function.arguments).map_err(|err| {

router/src/infer/chat_template.rs
Lines changed: 13 additions & 7 deletions

@@ -96,15 +96,21 @@ impl ChatTemplate {
 
         let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();
         let final_message = messages.last().cloned();
+        let template_inputs = ChatTemplateInputs {
+            messages,
+            bos_token: self.bos_token.as_deref(),
+            eos_token: self.eos_token.as_deref(),
+            add_generation_prompt: true,
+            tools,
+        };
+
+        // NOTE: initalizing `template_inputs` is helpful when JSON dumping the
+        // `ChatTemplateInputs` struct for debugging
+        // let template_inputs_as_json = serde_json::to_string(&template_inputs).unwrap();
+
         let mut rendered_template = self
             .template
-            .render(ChatTemplateInputs {
-                messages,
-                bos_token: self.bos_token.as_deref(),
-                eos_token: self.eos_token.as_deref(),
-                add_generation_prompt: true,
-                tools,
-            })
+            .render(template_inputs)
             .map_err(InferError::TemplateError)?;
 
         // if the last message is from the assistant, continue the generation prompt
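
The refactor above binds the inputs to a named `template_inputs` before rendering, mainly so the struct can be dumped while debugging, as the commented-out `serde_json::to_string` line hints. Below is a minimal, standalone sketch of that pattern; the pared-down struct is an assumption for illustration (the real `ChatTemplateInputs` in router/src/lib.rs holds `Vec<TextMessage>` and a `tools` field).

use serde::Serialize;

// Simplified stand-in for ChatTemplateInputs; field names follow the diff,
// but messages are plain strings here instead of TextMessage.
#[derive(Debug, Serialize, Default)]
struct ChatTemplateInputs<'a> {
    messages: Vec<String>,
    bos_token: Option<&'a str>,
    eos_token: Option<&'a str>,
    add_generation_prompt: bool,
}

fn main() {
    let template_inputs = ChatTemplateInputs {
        messages: vec!["Hello".to_string()],
        bos_token: Some("<s>"),
        eos_token: Some("</s>"),
        add_generation_prompt: true,
    };
    // Because the inputs now live in a local binding, they can be inspected
    // before rendering; the `Debug` derive added in lib.rs enables `{:?}` too.
    println!("{}", serde_json::to_string(&template_inputs).unwrap());
    println!("{:?}", template_inputs);
}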

router/src/infer/tool_grammar.rs
Lines changed: 1 addition & 0 deletions

@@ -34,6 +34,7 @@ impl ToolGrammar {
             .chain(std::iter::once(Tool {
                 r#type: "function".to_string(),
                 function: FunctionDefinition {
+                    id: None,
                     name: "no_tool".to_string(),
                     description: Some(
                         "Open ended response with no specific tool selected".to_string(),

router/src/lib.rs
Lines changed: 29 additions & 8 deletions

@@ -18,6 +18,7 @@ use crate::infer::{Infer, InferError};
 use pyo3::prelude::*;
 use pyo3::types::IntoPyDict;
 use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
 use tokenizers::Encoding;
 use tracing::warn;
 use utoipa::ToSchema;

@@ -912,7 +913,10 @@ pub(crate) struct ChatRequest {
 }
 
 impl ChatRequest {
-    fn try_into_generate(self, infer: &Infer) -> Result<(GenerateRequest, bool), InferError> {
+    fn try_into_generate(
+        self,
+        infer: &Infer,
+    ) -> Result<(GenerateRequest, Option<HashMap<String, String>>), InferError> {
         let ChatRequest {
             model,
             max_tokens,

@@ -952,7 +956,7 @@ impl ChatRequest {
         let (inputs, grammar, using_tools) = match response_format {
            Some(format) => {
                let inputs = infer.apply_chat_template(messages, None)?;
-                (inputs, Some(format), false)
+                (inputs, Some(format), None)
            }
            None => {
                if let Some(tools) = tools {

@@ -961,20 +965,31 @@
                            let grammar = GrammarType::Json(serde_json::json!(tool_schema));
                            let inputs: String = infer.apply_chat_template(
                                messages,
-                                Some((updated_tools, tool_prompt)),
+                                Some((updated_tools.clone(), tool_prompt)),
                            )?;
-                            (inputs, Some(grammar), true)
+                            let tool_name_to_id: HashMap<String, String> = updated_tools
+                                .into_iter()
+                                .map(|tool| {
+                                    (
+                                        tool.function.name,
+                                        tool.function
+                                            .id
+                                            .map_or_else(|| "0".to_string(), |id| id.to_string()),
+                                    )
+                                })
+                                .collect();
+                            (inputs, Some(grammar), Some(tool_name_to_id))
                        }
                        None => {
                            // same as if no response_format or tools are set
                            let inputs = infer.apply_chat_template(messages, None)?;
-                            (inputs, None, false)
+                            (inputs, None, None)
                        }
                    }
                } else {
                    // if no response_format or tools are set simply apply the chat template to generate inputs
                    let inputs = infer.apply_chat_template(messages, None)?;
-                    (inputs, None, false)
+                    (inputs, None, None)
                }
            }
        };

@@ -1154,6 +1169,8 @@ pub struct FunctionDefinition {
     #[serde(default)]
     pub description: Option<String>,
     pub name: String,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub id: Option<String>,
     #[serde(alias = "parameters", serialize_with = "serialize_as_string")]
     pub arguments: serde_json::Value,
 }

@@ -1175,7 +1192,7 @@ pub(crate) struct Tool {
     pub function: FunctionDefinition,
 }
 
-#[derive(Clone, Serialize, Deserialize, Default)]
+#[derive(Debug, Clone, Serialize, Deserialize, Default)]
 pub(crate) struct ChatTemplateInputs<'a> {
     messages: Vec<TextMessage>,
     bos_token: Option<&'a str>,

@@ -1208,6 +1225,9 @@ pub enum MessageChunk {
 pub struct Message {
     #[schema(example = "user")]
     pub role: String,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    #[schema(example = "10")]
+    pub tool_call_id: Option<String>,
     #[serde(flatten)]
     #[schema(example = "My name is David and I")]
     pub body: MessageBody,

@@ -1287,7 +1307,7 @@ impl From<Message> for TextMessage {
                    .collect::<Vec<_>>()
                    .join(""),
            },
-            ..Default::default()
+            tool_call_id: value.tool_call_id,
        }
    }
 }

@@ -1758,6 +1778,7 @@ mod tests {
             id: "0".to_string(),
             r#type: "function".to_string(),
             function: FunctionDefinition {
+                id: None,
                 description: None,
                 name: "myfn".to_string(),
                 arguments: json!({
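
The heart of the lib.rs change is the `tool_name_to_id` map: instead of a plain `bool`, `try_into_generate` now hands back an optional map from each tool's function name to the caller-supplied `id`, defaulting to "0" when none was given. Below is a standalone sketch of that mapping; the pared-down `Tool` and `FunctionDefinition` stand-ins and the tool names and ids in `main` are made up for illustration.

use std::collections::HashMap;

// Simplified stand-ins for the router's FunctionDefinition and Tool (assumed, pared-down shapes).
struct FunctionDefinition {
    name: String,
    id: Option<String>,
}

struct Tool {
    function: FunctionDefinition,
}

// Build a name -> id lookup; tools without an id fall back to "0",
// mirroring the map_or_else in the diff above.
fn tool_name_to_id(tools: Vec<Tool>) -> HashMap<String, String> {
    tools
        .into_iter()
        .map(|tool| {
            (
                tool.function.name,
                tool.function.id.unwrap_or_else(|| "0".to_string()),
            )
        })
        .collect()
}

fn main() {
    let tools = vec![
        Tool {
            function: FunctionDefinition {
                name: "get_weather".to_string(),
                id: Some("call_abc123".to_string()),
            },
        },
        Tool {
            function: FunctionDefinition {
                name: "no_tool".to_string(),
                id: None,
            },
        },
    ];
    let map = tool_name_to_id(tools);
    assert_eq!(map["get_weather"], "call_abc123");
    assert_eq!(map["no_tool"], "0");
}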

router/src/server.rs
Lines changed: 18 additions & 11 deletions

@@ -1165,8 +1165,7 @@ pub(crate) async fn chat_completions(
 
     tracing::debug!("Got chat_template {:?}", infer.chat_template);
     let id = chat.next_tool_call_id();
-    let (generate_request, using_tools): (GenerateRequest, bool) =
-        chat.clone().try_into_generate(&infer)?;
+    let (generate_request, using_tools) = chat.clone().try_into_generate(&infer)?;
     span.record("parameters", format!("{:?}", generate_request.parameters));
     let logprobs = logprobs.unwrap_or_default();
 

@@ -1188,7 +1187,7 @@
 
         let response_stream = async_stream::stream! {
             let mut response_stream = Box::pin(response_stream);
-            let mut state = ChatState::new(using_tools, stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone());
+            let mut state = ChatState::new(using_tools.is_some(), stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone());
             while let Some(result) = response_stream.next().await {
                 match result{
                     Ok(stream_token) => {

@@ -1197,12 +1196,12 @@
                            ChatEvent::NoTool => {
                                chat.tools = None;
                                chat.response_format = None;
-                                let (generate_request, using_tools): (GenerateRequest, bool) =
+                                let (generate_request, using_tools) =
                                    chat.clone().try_into_generate(&infer).unwrap();
-                                assert!(!using_tools);
+                                assert!(using_tools.is_none());
                                let (_headers, response_stream2) =
                                    generate_stream_internal(infer.clone(), compute_type.clone(), Json(generate_request), span.clone()).await;
-                                state = ChatState::new(using_tools, stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone());
+                                state = ChatState::new(using_tools.is_some(), stream_options.clone(), system_fingerprint.clone(), model_id.clone(), logprobs, id.clone());
                                response_stream = Box::pin(response_stream2);
                            }
                            ChatEvent::Events(events) => {

@@ -1237,14 +1236,13 @@
            .unwrap_or_else(|_| std::time::Duration::from_secs(0))
            .as_secs();
 
-        let (tool_calls, output) = if using_tools {
+        let (tool_calls, output) = if using_tools.is_some() {
            match crate::chat::parse_output(&generation.generated_text)? {
                ChatChoice::NoTool => {
                    chat.tools = None;
                    chat.response_format = None;
-                    let (generate_request, using_tools): (GenerateRequest, bool) =
-                        chat.clone().try_into_generate(&infer)?;
-                    assert!(!using_tools);
+                    let (generate_request, using_tools) = chat.clone().try_into_generate(&infer)?;
+                    assert!(using_tools.is_none());
                    let (headers_final, input_length_final, Json(generation)) = generate_internal(
                        Extension(infer),
                        compute_type,

@@ -1256,7 +1254,16 @@
                    input_length = input_length_final;
                    (None, Some(generation.generated_text))
                }
-                ChatChoice::ToolCalls(tool_calls) => (Some(tool_calls), None),
+                ChatChoice::ToolCalls(mut tool_calls) => {
+                    // assign the tool ids based on the tool names
+                    tool_calls.iter_mut().for_each(|tool_call| {
+                        tool_call.id = using_tools
+                            .as_ref()
+                            .and_then(|tools| tools.get(&tool_call.function.name))
+                            .map_or("0".to_string(), |id| id.clone());
+                    });
+                    (Some(tool_calls), None)
+                }
            }
        } else {
            (None, Some(generation.generated_text))
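
On the response side, server.rs uses that optional map to stamp each parsed tool call with the id registered for its function name, again falling back to "0". Below is a self-contained sketch of just the assignment logic; the `ToolCall` shape and the names and ids in `main` are assumptions for illustration.

use std::collections::HashMap;

// Pared-down stand-ins for a parsed tool call (assumed shapes).
struct ToolCallFunction {
    name: String,
}

struct ToolCall {
    id: String,
    function: ToolCallFunction,
}

// Assign tool ids based on the tool names, mirroring the loop added in server.rs.
fn assign_ids(tool_calls: &mut [ToolCall], using_tools: Option<&HashMap<String, String>>) {
    for tool_call in tool_calls.iter_mut() {
        tool_call.id = using_tools
            .and_then(|tools| tools.get(&tool_call.function.name))
            .map_or("0".to_string(), |id| id.clone());
    }
}

fn main() {
    let mut calls = vec![ToolCall {
        id: String::new(),
        function: ToolCallFunction {
            name: "get_weather".to_string(),
        },
    }];

    let mut name_to_id = HashMap::new();
    name_to_id.insert("get_weather".to_string(), "call_abc123".to_string());

    assign_ids(&mut calls, Some(&name_to_id));
    assert_eq!(calls[0].id, "call_abc123");

    // With no map (e.g. the request carried no ids), the fallback "0" is used.
    assign_ids(&mut calls, None);
    assert_eq!(calls[0].id, "0");
}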

router/src/vertex.rs
Lines changed: 1 addition & 2 deletions

@@ -104,8 +104,7 @@ pub(crate) async fn vertex_compatibility(
            },
        },
        VertexInstance::Chat(instance) => {
-            let (generate_request, _using_tools): (GenerateRequest, bool) =
-                instance.try_into_generate(&infer)?;
+            let (generate_request, _using_tools) = instance.try_into_generate(&infer)?;
            generate_request
        }
    };
