Skip to content

Commit 5059cd4

Browse files
authored
Tokenizers fall back to BPE if unregistered (#231)
* Tokenizers fall back to BPE if unregistered: Now that the library has been tested more extensively, it's not so valuable to fail on tokenizers we haven't encountered before. We still need to check what happens with tokenizers that use a different model, or those that haven't been ported from their original implementations. An alternative would be to expose a registration mechanism, as discussed in #63. * Issue warning on unregistered tokenizer * strict mode (throw by default) * fix
1 parent 719d31e commit 5059cd4

File tree

2 files changed

+51
-33
lines changed

2 files changed

+51
-33
lines changed

Sources/Tokenizers/Tokenizer.swift

Lines changed: 30 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -102,37 +102,42 @@ public protocol PreTrainedTokenizerModel: TokenizingModel {
102102
struct TokenizerModel {
103103
static let knownTokenizers: [String: PreTrainedTokenizerModel.Type] = [
104104
"BertTokenizer": BertTokenizer.self,
105+
"CodeGenTokenizer": BPETokenizer.self,
106+
"CodeLlamaTokenizer": BPETokenizer.self,
107+
"CohereTokenizer": BPETokenizer.self,
105108
"DistilbertTokenizer": BertTokenizer.self,
106109
"DistilBertTokenizer": BertTokenizer.self,
110+
"FalconTokenizer": BPETokenizer.self,
111+
"GemmaTokenizer": BPETokenizer.self,
112+
"GPT2Tokenizer": BPETokenizer.self,
113+
"LlamaTokenizer": BPETokenizer.self,
107114
"RobertaTokenizer": BPETokenizer.self,
108-
"CodeGenTokenizer": CodeGenTokenizer.self,
109-
"CodeLlamaTokenizer": CodeLlamaTokenizer.self,
110-
"FalconTokenizer": FalconTokenizer.self,
111-
"GemmaTokenizer": GemmaTokenizer.self,
112-
"GPT2Tokenizer": GPT2Tokenizer.self,
113-
"LlamaTokenizer": LlamaTokenizer.self,
114115
"T5Tokenizer": T5Tokenizer.self,
115-
"WhisperTokenizer": WhisperTokenizer.self,
116-
"CohereTokenizer": CohereTokenizer.self,
117-
"Qwen2Tokenizer": Qwen2Tokenizer.self,
118116
"PreTrainedTokenizer": BPETokenizer.self,
117+
"Qwen2Tokenizer": BPETokenizer.self,
118+
"WhisperTokenizer": BPETokenizer.self,
119119
]
120120

121121
static func unknownToken(from tokenizerConfig: Config) -> String? {
122122
tokenizerConfig.unkToken.content.string() ?? tokenizerConfig.unkToken.string()
123123
}
124124

125-
static func from(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String: Int]) throws -> TokenizingModel {
125+
static func from(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String: Int], strict: Bool = true) throws -> TokenizingModel {
126126
guard let tokenizerClassName = tokenizerConfig.tokenizerClass.string() else {
127127
throw TokenizerError.missingTokenizerClassInConfig
128128
}
129129

130130
// Some tokenizer_class entries use a Fast suffix
131131
let tokenizerName = tokenizerClassName.replacingOccurrences(of: "Fast", with: "")
132-
guard let tokenizerClass = TokenizerModel.knownTokenizers[tokenizerName] else {
133-
throw TokenizerError.unsupportedTokenizer(tokenizerName)
132+
// Fallback to BPETokenizer if class is not explicitly registered
133+
let tokenizerClass = TokenizerModel.knownTokenizers[tokenizerName] ?? BPETokenizer.self
134+
if TokenizerModel.knownTokenizers[tokenizerName] == nil {
135+
if strict {
136+
throw TokenizerError.unsupportedTokenizer(tokenizerName)
137+
} else {
138+
print("Warning: Tokenizer model class \(tokenizerName) is not registered, falling back to a standard BPE implementation.")
139+
}
134140
}
135-
136141
return try tokenizerClass.init(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens)
137142
}
138143
}
@@ -288,7 +293,7 @@ public class PreTrainedTokenizer: Tokenizer {
288293
/// Cache for compiled Jinja templates keyed by their literal template string
289294
private var compiledChatTemplateCache: [String: Template] = [:]
290295

291-
public required init(tokenizerConfig: Config, tokenizerData: Config) throws {
296+
public required init(tokenizerConfig: Config, tokenizerData: Config, strict: Bool = true) throws {
292297
var addedTokens: [String: Int] = [:]
293298
var specialTokens: [String: Int] = [:]
294299
for addedToken in tokenizerData["addedTokens"].array(or: []) {
@@ -331,7 +336,7 @@ public class PreTrainedTokenizer: Tokenizer {
331336
cleanUpTokenizationSpaces = tokenizerConfig.cleanUpTokenizationSpaces.boolean(or: true)
332337
self.tokenizerConfig = tokenizerConfig
333338

334-
model = try TokenizerModel.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens)
339+
model = try TokenizerModel.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, addedTokens: addedTokens, strict: strict)
335340
}
336341

337342
private func compiledTemplate(for templateString: String) throws -> Template {
@@ -615,46 +620,38 @@ public extension AutoTokenizer {
615620
return PreTrainedTokenizer.self
616621
}
617622

618-
static func from(tokenizerConfig: Config, tokenizerData: Config) throws -> Tokenizer {
623+
static func from(tokenizerConfig: Config, tokenizerData: Config, strict: Bool = true) throws -> Tokenizer {
619624
let tokenizerClass = tokenizerClass(for: tokenizerConfig)
620-
return try tokenizerClass.init(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
625+
return try tokenizerClass.init(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, strict: strict)
621626
}
622627

623628
static func from(
624629
pretrained model: String,
625-
hubApi: HubApi = .shared
630+
hubApi: HubApi = .shared,
631+
strict: Bool = true
626632
) async throws -> Tokenizer {
627633
let config = LanguageModelConfigurationFromHub(modelName: model, hubApi: hubApi)
628634
guard let tokenizerConfig = try await config.tokenizerConfig else { throw TokenizerError.missingConfig }
629635
let tokenizerData = try await config.tokenizerData
630636

631-
return try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
637+
return try AutoTokenizer.from(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, strict: strict)
632638
}
633639

634640
static func from(
635641
modelFolder: URL,
636-
hubApi: HubApi = .shared
642+
hubApi: HubApi = .shared,
643+
strict: Bool = true
637644
) async throws -> Tokenizer {
638645
let config = LanguageModelConfigurationFromHub(modelFolder: modelFolder, hubApi: hubApi)
639646
guard let tokenizerConfig = try await config.tokenizerConfig else { throw TokenizerError.missingConfig }
640647
let tokenizerData = try await config.tokenizerData
641648

642-
return try PreTrainedTokenizer(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
649+
return try PreTrainedTokenizer(tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData, strict: strict)
643650
}
644651
}
645652

646653
// MARK: - Tokenizer model classes
647654

648-
class GPT2Tokenizer: BPETokenizer { }
649-
class FalconTokenizer: BPETokenizer { }
650-
class LlamaTokenizer: BPETokenizer { }
651-
class CodeGenTokenizer: BPETokenizer { }
652-
class WhisperTokenizer: BPETokenizer { }
653-
class GemmaTokenizer: BPETokenizer { }
654-
class CodeLlamaTokenizer: BPETokenizer { }
655-
class CohereTokenizer: BPETokenizer { }
656-
class Qwen2Tokenizer: BPETokenizer { }
657-
658655
class T5Tokenizer: UnigramTokenizer { }
659656

660657
// MARK: - PreTrainedTokenizer classes
@@ -707,7 +704,7 @@ func maybeUpdatePostProcessor(tokenizerConfig: Config, processorConfig: Config?)
707704
class LlamaPreTrainedTokenizer: PreTrainedTokenizer {
708705
let isLegacy: Bool
709706

710-
required init(tokenizerConfig: Config, tokenizerData: Config) throws {
707+
required init(tokenizerConfig: Config, tokenizerData: Config, strict: Bool = true) throws {
711708
isLegacy = tokenizerConfig.legacy.boolean(or: true)
712709
var configDictionary = tokenizerData.dictionary(or: [:])
713710
if !isLegacy {
@@ -722,6 +719,6 @@ class LlamaPreTrainedTokenizer: PreTrainedTokenizer {
722719
}
723720

724721
let updatedData = Config(configDictionary)
725-
try super.init(tokenizerConfig: tokenizerConfig, tokenizerData: updatedData)
722+
try super.init(tokenizerConfig: tokenizerConfig, tokenizerData: updatedData, strict: strict)
726723
}
727724
}

Tests/TokenizersTests/TokenizerTests.swift

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,27 @@ class PhiSimpleTests: XCTestCase {
120120
}
121121
}
122122

123+
class UnregisteredTokenizerTests: XCTestCase {
124+
func testNllbTokenizer() async throws {
125+
do {
126+
_ = try await AutoTokenizer.from(pretrained: "Xenova/nllb-200-distilled-600M")
127+
XCTFail("Expected AutoTokenizer.from to throw for strict mode")
128+
} catch {
129+
// Expected to throw in normal (strict) mode
130+
}
131+
132+
        // non-strict mode proceeds with the BPE fallback
133+
guard let tokenizer = try await AutoTokenizer.from(pretrained: "Xenova/nllb-200-distilled-600M", strict: false) as? PreTrainedTokenizer else {
134+
XCTFail()
135+
return
136+
}
137+
138+
let ids = tokenizer.encode(text: "Why did the chicken cross the road?")
139+
let expected = [256047, 24185, 4077, 349, 1001, 22690, 83580, 349, 82801, 248130, 2]
140+
XCTAssertEqual(ids, expected)
141+
}
142+
}
143+
123144
class LlamaPostProcessorOverrideTests: XCTestCase {
124145
/// Deepseek needs a post-processor override to add a bos token as in the reference implementation
125146
func testDeepSeek() async throws {

0 commit comments

Comments
 (0)