@@ -18,6 +18,7 @@ use crate::infer::{Infer, InferError};
1818use pyo3:: prelude:: * ;
1919use pyo3:: types:: IntoPyDict ;
2020use serde:: { Deserialize , Serialize } ;
21+ use std:: collections:: HashMap ;
2122use tokenizers:: Encoding ;
2223use tracing:: warn;
2324use utoipa:: ToSchema ;
@@ -912,7 +913,10 @@ pub(crate) struct ChatRequest {
912913}
913914
914915impl ChatRequest {
915- fn try_into_generate ( self , infer : & Infer ) -> Result < ( GenerateRequest , bool ) , InferError > {
916+ fn try_into_generate (
917+ self ,
918+ infer : & Infer ,
919+ ) -> Result < ( GenerateRequest , Option < HashMap < String , String > > ) , InferError > {
916920 let ChatRequest {
917921 model,
918922 max_tokens,
@@ -952,7 +956,7 @@ impl ChatRequest {
952956 let ( inputs, grammar, using_tools) = match response_format {
953957 Some ( format) => {
954958 let inputs = infer. apply_chat_template ( messages, None ) ?;
955- ( inputs, Some ( format) , false )
959+ ( inputs, Some ( format) , None )
956960 }
957961 None => {
958962 if let Some ( tools) = tools {
@@ -961,20 +965,31 @@ impl ChatRequest {
961965 let grammar = GrammarType :: Json ( serde_json:: json!( tool_schema) ) ;
962966 let inputs: String = infer. apply_chat_template (
963967 messages,
964- Some ( ( updated_tools, tool_prompt) ) ,
968+ Some ( ( updated_tools. clone ( ) , tool_prompt) ) ,
965969 ) ?;
966- ( inputs, Some ( grammar) , true )
970+ let tool_name_to_id: HashMap < String , String > = updated_tools
971+ . into_iter ( )
972+ . map ( |tool| {
973+ (
974+ tool. function . name ,
975+ tool. function
976+ . id
977+ . map_or_else ( || "0" . to_string ( ) , |id| id. to_string ( ) ) ,
978+ )
979+ } )
980+ . collect ( ) ;
981+ ( inputs, Some ( grammar) , Some ( tool_name_to_id) )
967982 }
968983 None => {
969984 // same as if no response_format or tools are set
970985 let inputs = infer. apply_chat_template ( messages, None ) ?;
971- ( inputs, None , false )
986+ ( inputs, None , None )
972987 }
973988 }
974989 } else {
975990 // if no response_format or tools are set simply apply the chat template to generate inputs
976991 let inputs = infer. apply_chat_template ( messages, None ) ?;
977- ( inputs, None , false )
992+ ( inputs, None , None )
978993 }
979994 }
980995 } ;
@@ -1154,6 +1169,8 @@ pub struct FunctionDefinition {
11541169 #[ serde( default ) ]
11551170 pub description : Option < String > ,
11561171 pub name : String ,
1172+ #[ serde( default , skip_serializing_if = "Option::is_none" ) ]
1173+ pub id : Option < String > ,
11571174 #[ serde( alias = "parameters" , serialize_with = "serialize_as_string" ) ]
11581175 pub arguments : serde_json:: Value ,
11591176}
@@ -1175,7 +1192,7 @@ pub(crate) struct Tool {
11751192 pub function : FunctionDefinition ,
11761193}
11771194
1178- #[ derive( Clone , Serialize , Deserialize , Default ) ]
1195+ #[ derive( Debug , Clone , Serialize , Deserialize , Default ) ]
11791196pub ( crate ) struct ChatTemplateInputs < ' a > {
11801197 messages : Vec < TextMessage > ,
11811198 bos_token : Option < & ' a str > ,
@@ -1208,6 +1225,9 @@ pub enum MessageChunk {
12081225pub struct Message {
12091226 #[ schema( example = "user" ) ]
12101227 pub role : String ,
1228+ #[ serde( default , skip_serializing_if = "Option::is_none" ) ]
1229+ #[ schema( example = "10" ) ]
1230+ pub tool_call_id : Option < String > ,
12111231 #[ serde( flatten) ]
12121232 #[ schema( example = "My name is David and I" ) ]
12131233 pub body : MessageBody ,
@@ -1287,7 +1307,7 @@ impl From<Message> for TextMessage {
12871307 . collect :: < Vec < _ > > ( )
12881308 . join ( "" ) ,
12891309 } ,
1290- .. Default :: default ( )
1310+ tool_call_id : value . tool_call_id ,
12911311 }
12921312 }
12931313}
@@ -1758,6 +1778,7 @@ mod tests {
17581778 id: "0" . to_string( ) ,
17591779 r#type: "function" . to_string( ) ,
17601780 function: FunctionDefinition {
1781+ id: None ,
17611782 description: None ,
17621783 name: "myfn" . to_string( ) ,
17631784 arguments: json!( {
0 commit comments