Skip to content

Commit ed54f3c

Browse files
tpae and pcuenca authored
updated applyChatTemplate with lazy memoization (#218)
* updated applyChatTemplate with lazy memoization
* formatted code
* Reused shared phi tokenizer

---------

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
1 parent 5ba776a commit ed54f3c

File tree

2 files changed

+42
-1
lines changed

2 files changed

+42
-1
lines changed

Sources/Tokenizers/Tokenizer.swift

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,9 @@ public class PreTrainedTokenizer: Tokenizer {
285285

286286
private let cleanUpTokenizationSpaces: Bool
287287

288+
/// Cache for compiled Jinja templates keyed by their literal template string
289+
private var compiledChatTemplateCache: [String: Template] = [:]
290+
288291
public required init(tokenizerConfig: Config, tokenizerData: Config) throws {
289292
var addedTokens: [String: Int] = [:]
290293
var specialTokens: [String: Int] = [:]
@@ -332,6 +335,15 @@ public class PreTrainedTokenizer: Tokenizer {
332335
model = try TokenizerModel.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens)
333336
}
334337

338+
/// Returns the compiled Jinja template for `templateString`, compiling it on first use.
///
/// Compiled templates are stored in `compiledChatTemplateCache` under the literal template
/// string, so later calls with an identical string skip compilation.
///
/// - Parameter templateString: The template source to compile.
/// - Returns: The cached `Template` when one exists; otherwise a newly compiled one.
/// - Throws: Any error raised by `Template(templateString)`; nothing is cached in that case.
///
/// NOTE(review): nothing in this method evicts entries or synchronizes access to the cache —
/// confirm callers neither pass unbounded distinct templates nor call it concurrently.
private func compiledTemplate(for templateString: String) throws -> Template {
339+
// Cache hit: reuse the previously compiled template.
if let cached = compiledChatTemplateCache[templateString] {
340+
return cached
341+
}
342+
// Cache miss: compile first, so a throwing compile leaves the cache unchanged.
let compiled = try Template(templateString)
343+
compiledChatTemplateCache[templateString] = compiled
344+
return compiled
345+
}
346+
335347
func preTokenize(_ text: String, options: PreTokenizerOptions) -> [String] {
336348
guard let preTokenizer else { return [text] }
337349
return preTokenizer(text: text, options: options)
@@ -530,7 +542,7 @@ public class PreTrainedTokenizer: Tokenizer {
530542
throw TokenizerError.missingChatTemplate
531543
}
532544

533-
let template = try Template(selectedChatTemplate)
545+
let template = try compiledTemplate(for: selectedChatTemplate)
534546
var context: [String: Any] = [
535547
"messages": messages,
536548
"add_generation_prompt": addGenerationPrompt,

Tests/TokenizersTests/ChatTemplateTests.swift

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
// Created by Anthony DePasquale on 2/10/24.
66
//
77

8+
import Foundation
89
import Tokenizers
910
import XCTest
1011

@@ -277,4 +278,32 @@ class ChatTemplateTests: XCTestCase {
277278
}
278279
}
279280
}
281+
282+
/// Performance: measures `applyChatTemplate` when the same template string is passed on every
/// call, i.e. the cached half of the cached-vs-uncached comparison.
283+
func testApplyChatTemplatePerformanceCached() async throws {
284+
let tokenizer = try await Self.sharedPhiTokenizer()
285+
286+
// Purposely reuse the same template literal to hit the memoized compiled template
287+
let mistral7BDefaultTemplate = "{{bos_token}}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
288+
289+
// Prime cache once, outside `measure`, so no measured iteration pays for the first compilation.
// `messages` is not declared in this method — presumably a fixture on the test class; verify.
290+
_ = try tokenizer.applyChatTemplate(messages: messages, chatTemplate: mistral7BDefaultTemplate)
291+
292+
// The closure given to `measure` does not throw, hence `try!`: an error from
// `applyChatTemplate` would trap here rather than be reported as a test failure.
measure(metrics: [XCTClockMetric()]) {
293+
_ = try! tokenizer.applyChatTemplate(messages: messages, chatTemplate: mistral7BDefaultTemplate)
294+
}
295+
}
296+
297+
/// Performance: simulate uncached runs by varying the template to bypass memoization.
/// Counterpart to `testApplyChatTemplatePerformanceCached`; no priming call is made here.
298+
func testApplyChatTemplatePerformanceUncached() async throws {
299+
let tokenizer = try await Self.sharedPhiTokenizer()
300+
301+
// Same template text as the cached test; a unique suffix is appended per iteration below.
let baseTemplate = "{{bos_token}}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
302+
303+
// NOTE(review): the UUID generation and string concatenation sit inside the measured block,
// so their cost is included in the reported time alongside template compilation.
measure(metrics: [XCTClockMetric()]) {
304+
// Make the template string unique each iteration to force a fresh compilation
// (the appended `{# ... #}` is presumably a Jinja comment that leaves the output unchanged — confirm).
305+
let uniqueTemplate = baseTemplate + "{# perf \(UUID().uuidString) #}"
306+
// `try!` because the `measure` closure does not throw; a thrown error traps instead of
// failing the test. `messages` is not declared in this method — presumably a class fixture.
_ = try! tokenizer.applyChatTemplate(messages: messages, chatTemplate: uniqueTemplate)
307+
}
308+
}
280309
}

0 commit comments

Comments
 (0)