Skip to content

Commit b2bc56e

Browse files
authored
Tokenizers don't need config.json (#261)
1 parent a0d99b6 commit b2bc56e

File tree

3 files changed

+20
-14
lines changed

3 files changed

+20
-14
lines changed

Sources/Hub/Hub.swift

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ public extension Hub {
7070

7171
public final class LanguageModelConfigurationFromHub: Sendable {
7272
struct Configurations {
73-
var modelConfig: Config
73+
var modelConfig: Config?
7474
var tokenizerConfig: Config?
7575
var tokenizerData: Config
7676
}
@@ -96,7 +96,7 @@ public final class LanguageModelConfigurationFromHub: Sendable {
9696
}
9797
}
9898

99-
public var modelConfig: Config {
99+
public var modelConfig: Config? {
100100
get async throws {
101101
try await configPromise.value.modelConfig
102102
}
@@ -135,7 +135,7 @@ public final class LanguageModelConfigurationFromHub: Sendable {
135135

136136
public var modelType: String? {
137137
get async throws {
138-
try await modelConfig.modelType.string()
138+
try await modelConfig?.modelType.string()
139139
}
140140
}
141141

@@ -174,11 +174,11 @@ public final class LanguageModelConfigurationFromHub: Sendable {
174174
do {
175175
// Load required configurations
176176
let modelConfigURL = modelFolder.appending(path: "config.json")
177-
guard FileManager.default.fileExists(atPath: modelConfigURL.path) else {
178-
throw Hub.HubClientError.configurationMissing("config.json")
179-
}
180177

181-
let modelConfig = try hubApi.configuration(fileURL: modelConfigURL)
178+
var modelConfig: Config? = nil
179+
if FileManager.default.fileExists(atPath: modelConfigURL.path) {
180+
modelConfig = try hubApi.configuration(fileURL: modelConfigURL)
181+
}
182182

183183
let tokenizerDataURL = modelFolder.appending(path: "tokenizer.json")
184184
guard FileManager.default.fileExists(atPath: tokenizerDataURL.path) else {

Sources/Models/LanguageModel.swift

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ public extension LanguageModel {
141141

142142
/// async properties downloaded from the configuration
143143
public extension LanguageModel {
144-
var modelConfig: Config {
144+
var modelConfig: Config? {
145145
get async throws {
146146
try await configuration!.modelConfig
147147
}
@@ -161,13 +161,13 @@ public extension LanguageModel {
161161

162162
var modelType: String? {
163163
get async throws {
164-
try await modelConfig.modelType.string()
164+
try await modelConfig?.modelType.string()
165165
}
166166
}
167167

168168
var textGenerationParameters: Config? {
169169
get async throws {
170-
try await modelConfig.taskSpecificParams.textGeneration
170+
try await modelConfig?.taskSpecificParams.textGeneration
171171
}
172172
}
173173

@@ -180,14 +180,14 @@ public extension LanguageModel {
180180
var bosTokenId: Int? {
181181
get async throws {
182182
let modelConfig = try await modelConfig
183-
return modelConfig.bosTokenId.integer()
183+
return modelConfig?.bosTokenId.integer()
184184
}
185185
}
186186

187187
var eosTokenId: Int? {
188188
get async throws {
189189
let modelConfig = try await modelConfig
190-
return modelConfig.eosTokenId.integer()
190+
return modelConfig?.eosTokenId.integer()
191191
}
192192
}
193193

Tests/HubTests/HubTests.swift

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@ class HubTests: XCTestCase {
2929
func testConfigDownload() async {
3030
do {
3131
let configLoader = LanguageModelConfigurationFromHub(modelName: "t5-base", hubApi: hubApi)
32-
let config = try await configLoader.modelConfig
32+
guard let config = try await configLoader.modelConfig else {
33+
XCTFail("Test repo is expected to have a config.json file")
34+
return
35+
}
3336

3437
// Test leaf value (Int)
3538
guard let eos = config["eos_token_id"].integer() else {
@@ -71,7 +74,10 @@ class HubTests: XCTestCase {
7174
func testConfigCamelCase() async {
7275
do {
7376
let configLoader = LanguageModelConfigurationFromHub(modelName: "t5-base", hubApi: hubApi)
74-
let config = try await configLoader.modelConfig
77+
guard let config = try await configLoader.modelConfig else {
78+
XCTFail("Test repo is expected to have a config.json file")
79+
return
80+
}
7581

7682
// Test leaf value (Int)
7783
guard let eos = config["eosTokenId"].integer() else {

0 commit comments

Comments
 (0)