@@ -9,13 +9,13 @@ import Hub
99import Foundation
1010import Jinja
1111
12- enum TokenizerError : Error {
12+ enum TokenizerError : Error {
1313 case missingConfig
1414 case missingTokenizerClassInConfig
1515 case unsupportedTokenizer( String )
1616 case missingVocab
1717 case malformedVocab
18-
18+ case chatTemplate ( String )
1919 case tooLong( String )
2020}
2121
@@ -94,6 +94,13 @@ struct TokenizerModel {
9494 }
9595}
9696
97+ public enum ChatTemplateArgument {
98+ /// A Jinja template to use for the conversation. Normally it is not necessary to provide a template, since it will be read from the tokenizer config.
99+ case literal( String )
100+ /// For models whose tokenizer config includes multiple chat templates, the template can be specified by name. Normally this is not necessary.
101+ case name( String )
102+ }
103+
97104public protocol Tokenizer {
98105 func tokenize( text: String ) -> [ String ]
99106
@@ -117,15 +124,24 @@ public protocol Tokenizer {
117124 var eosTokenId : Int ? { get }
118125 var unknownToken : String ? { get }
119126 var unknownTokenId : Int ? { get }
120-
127+
128+ /// The appropriate chat template is selected from the tokenizer config
121129 func applyChatTemplate( messages: [ [ String : String ] ] ) throws -> [ Int ]
122-
130+
131+ /// The chat template is provided as a string literal or specified by name
132+ func applyChatTemplate( messages: [ [ String : String ] ] , chatTemplate: ChatTemplateArgument ) throws -> [ Int ]
133+
134+ /// The chat template is provided as a string literal
135+ func applyChatTemplate( messages: [ [ String : String ] ] , chatTemplate: String ) throws -> [ Int ]
136+
123137 func applyChatTemplate(
124138 messages: [ [ String : String ] ] ,
125- chatTemplate: String ? ,
139+ /// A chat template can optionally be provided or specified by name when several templates are included in the tokenizer config. Normally this is not necessary.
140+ chatTemplate: ChatTemplateArgument ? ,
126141 addGenerationPrompt: Bool ,
127142 truncation: Bool ,
128- maxLength: Int ?
143+ maxLength: Int ? ,
144+ tools: [ [ String : Any ] ] ?
129145 ) throws -> [ Int ]
130146}
131147
@@ -176,8 +192,6 @@ public class PreTrainedTokenizer: Tokenizer {
176192 private let tokenizerConfig : Config
177193
178194 private let cleanUpTokenizationSpaces : Bool
179-
180- private let defaultChatTemplate : String = " {% for message in messages %}{{'<|im_start|>' + message['role'] + ' \n ' + message['content'] + '<|im_end|>' + ' \n '}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant \n ' }}{% endif %} "
181195
182196 required public init ( tokenizerConfig: Config , tokenizerData: Config ) throws {
183197 var addedTokens : [ String : Int ] = [ : ]
@@ -222,7 +236,7 @@ public class PreTrainedTokenizer: Tokenizer {
222236 self . decoder = DecoderFactory . fromConfig ( config: tokenizerData. decoder)
223237 self . cleanUpTokenizationSpaces = tokenizerConfig. cleanUpTokenizationSpaces? . boolValue ?? true
224238 self . tokenizerConfig = tokenizerConfig
225-
239+
226240 model = try TokenizerModel . from ( tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens)
227241 }
228242
@@ -316,22 +330,76 @@ public class PreTrainedTokenizer: Tokenizer {
316330 public func convertIdToToken( _ id: Int ) -> String ? {
317331 model. convertIdToToken ( id)
318332 }
319-
333+
320334 public func applyChatTemplate( messages: [ [ String : String ] ] ) throws -> [ Int ] {
321- try applyChatTemplate ( messages: messages, chatTemplate: nil , addGenerationPrompt: true , maxLength: nil )
335+ try applyChatTemplate ( messages: messages, addGenerationPrompt: true )
336+ }
337+
338+ public func applyChatTemplate( messages: [ [ String : String ] ] , chatTemplate: ChatTemplateArgument ) throws -> [ Int ] {
339+ try applyChatTemplate ( messages: messages, chatTemplate: chatTemplate, addGenerationPrompt: true )
322340 }
323-
341+
342+ public func applyChatTemplate( messages: [ [ String : String ] ] , chatTemplate: String ) throws -> [ Int ] {
343+ try applyChatTemplate ( messages: messages, chatTemplate: . literal( chatTemplate) , addGenerationPrompt: true )
344+ }
345+
324346 public func applyChatTemplate(
325347 messages: [ [ String : String ] ] ,
326- chatTemplate: String ? ,
348+ chatTemplate: ChatTemplateArgument ? = nil ,
327349 addGenerationPrompt: Bool = false ,
328350 truncation: Bool = false ,
329- maxLength: Int ?
351+ maxLength: Int ? = nil ,
352+ /// A list of tools (callable functions) that will be accessible to the model. If the template does not
353+ /// support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema,
354+ /// giving the name, description and argument types for the tool. See the
355+ /// [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
356+ /// for more information.
357+ /// Note: tool calling is not supported yet, it will be available in a future update.
358+ tools: [ [ String : Any ] ] ? = nil
330359 ) throws -> [ Int ] {
331- let template = try Template ( chatTemplate ?? tokenizerConfig. chatTemplate? . stringValue ?? defaultChatTemplate)
360+ var selectedChatTemplate : String ?
361+ if let chatTemplate, case . literal( let template) = chatTemplate {
362+ // Use chat template from argument
363+ selectedChatTemplate = template
364+ } else if let valueFromConfig = tokenizerConfig. chatTemplate {
365+ if let arrayValue = valueFromConfig. arrayValue {
366+ // If the config specifies a list of chat templates, convert them to a dictionary
367+ let templateDict = Dictionary < String , String > ( uniqueKeysWithValues: arrayValue. compactMap { item in
368+ guard let name = item. name? . stringValue, let template = item. template? . stringValue else {
369+ return nil
370+ }
371+ return ( name, template)
372+ } )
373+ if let chatTemplate, case . name( let name) = chatTemplate {
374+ // Select chat template from config by name
375+ if let matchingDictEntry = templateDict [ name] {
376+ selectedChatTemplate = matchingDictEntry
377+ } else {
378+ throw TokenizerError . chatTemplate ( " No chat template named \" \( name) \" was found in the tokenizer config " )
379+ }
380+ } else if let tools, !tools. isEmpty, let toolUseTemplate = templateDict [ " tool_use " ] {
381+ // Use tool use chat template from config
382+ selectedChatTemplate = toolUseTemplate
383+ } else if let defaultChatTemplate = templateDict [ " default " ] {
384+ // Use default chat template from config
385+ selectedChatTemplate = defaultChatTemplate
386+ }
387+ } else if let stringValue = valueFromConfig. stringValue {
388+ // Use chat template from config
389+ selectedChatTemplate = stringValue
390+ }
391+ }
392+
393+ guard let selectedChatTemplate else {
394+ throw TokenizerError . chatTemplate ( " No chat template was specified " )
395+ }
396+
397+ let template = try Template ( selectedChatTemplate)
332398 var context : [ String : Any ] = [
333399 " messages " : messages,
334400 " add_generation_prompt " : addGenerationPrompt
401+ // TODO: Add `tools` entry when support is added in Jinja
402+ // "tools": tools
335403 ]
336404
337405 // TODO: maybe keep NSString here
@@ -397,15 +465,15 @@ extension AutoTokenizer {
397465
398466 return try AutoTokenizer . from ( tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
399467 }
400-
468+
401469 public static func from(
402470 modelFolder: URL ,
403471 hubApi: HubApi = . shared
404472 ) async throws -> Tokenizer {
405473 let config = LanguageModelConfigurationFromHub ( modelFolder: modelFolder, hubApi: hubApi)
406474 guard let tokenizerConfig = try await config. tokenizerConfig else { throw TokenizerError . missingConfig }
407475 let tokenizerData = try await config. tokenizerData
408-
476+
409477 return try PreTrainedTokenizer ( tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
410478 }
411479}
0 commit comments