Skip to content

Commit 48f08fb

Browse files
authored
Conditionally import CoreML and Accelerate frameworks (#247)
* Change Math type from a struct to an enum * Conditionally import CoreML and Accelerate frameworks
1 parent 69de809 commit 48f08fb

File tree

11 files changed

+31
-3
lines changed

11 files changed

+31
-3
lines changed

Sources/Generation/Generation.swift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
// Created by Pedro Cuenca on 7/5/23.
66
//
77

8+
#if canImport(CoreML)
89
import CoreML
10+
911
import Tokenizers
1012

1113
public enum GenerationMode {
@@ -106,3 +108,4 @@ public extension Generation {
106108
return logitsWarpers
107109
}
108110
}
111+
#endif // canImport(CoreML)

Sources/Generation/LogitsWarper/TopKLogitsWarper.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#if canImport(Accelerate)
12
import Accelerate
23
import Foundation
34

@@ -56,3 +57,4 @@ public struct TopKLogitsWarper: LogitsWarper {
5657
return (indices: topkIndices.map { indices[Int($0)] }, logits: topkLogits)
5758
}
5859
}
60+
#endif // canImport(Accelerate)

Sources/Generation/MLMultiArray+Utils.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
// Copyright © 2019 Hugging Face. All rights reserved.
77
//
88

9+
#if canImport(CoreML)
910
import CoreML
1011
import Foundation
1112

@@ -196,3 +197,4 @@ extension MLMultiArray {
196197
return s + "]"
197198
}
198199
}
200+
#endif // canImport(CoreML)

Sources/Generation/MLShapedArray+Utils.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
// Created by Pedro Cuenca on 13/5/23.
66
//
77

8+
#if canImport(CoreML)
89
import CoreML
910

1011
public extension MLShapedArray<Float> {
@@ -50,3 +51,4 @@ public extension MLMultiArray {
5051
}
5152
}
5253
}
54+
#endif // canImport(CoreML)

Sources/Generation/Math.swift

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
// Copyright © 2019 Hugging Face. All rights reserved.
77
//
88

9+
#if canImport(CoreML) && canImport(Accelerate)
910
import Accelerate
1011
import CoreML
1112
import Foundation
@@ -15,7 +16,7 @@ import Foundation
1516
///
1617
/// https://github.com/hollance/CoreMLHelpers
1718
///
18-
public struct Math {
19+
public enum Math {
1920
/**
2021
Returns the index and value of the largest element in the array.
2122

@@ -168,3 +169,4 @@ public extension Math {
168169
}
169170
}
170171
}
172+
#endif // canImport(CoreML) && canImport(Accelerate)

Sources/Models/LanguageModel.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
// Created by Pedro Cuenca on 7/5/23.
66
//
77

8+
#if canImport(CoreML)
89
import CoreML
10+
911
import Generation
1012
import Hub
1113
import Tokenizers
@@ -226,3 +228,5 @@ public enum TokenizerError: LocalizedError {
226228
}
227229
}
228230
}
231+
232+
#endif // canImport(CoreML)

Sources/Models/LanguageModelTypes.swift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
// Created by Pedro Cuenca on 8/5/23.
66
//
77

8+
#if canImport(CoreML)
89
import CoreML
10+
911
import Generation
1012
import Tokenizers
1113

@@ -40,3 +42,4 @@ public extension TextGenerationModel {
4042
try await generate(config: config, prompt: prompt, model: callAsFunction, tokenizer: tokenizer, callback: callback)
4143
}
4244
}
45+
#endif // canImport(CoreML)

Sources/Models/Weights.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#if canImport(CoreML)
12
import CoreML
23

34
public struct Weights {
@@ -91,3 +92,4 @@ struct Safetensor {
9192
return Weights(dict)
9293
}
9394
}
95+
#endif // canImport(CoreML)

Sources/TransformersCLI/main.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#if canImport(CoreML)
12
import ArgumentParser
23
import CoreML
34
import Foundation
@@ -107,3 +108,4 @@ extension Double {
107108
String(format: "\(format)", self)
108109
}
109110
}
111+
#endif // canImport(CoreML)

Tests/GenerationTests/LogitsWarperTests.swift

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44
// Created by Jan Krukowski on 09/12/2023.
55
//
66

7+
#if canImport(CoreML)
78
import CoreML
8-
@testable import Generation
99
import XCTest
1010

11+
@testable import Generation
12+
1113
final class LogitsWarperTests: XCTestCase {
1214
private let accuracy: Float = 0.00001
1315

@@ -150,3 +152,4 @@ final class LogitsWarperTests: XCTestCase {
150152
XCTAssertEqual(result5.logits, [4.5, 3.0, 2.0, 1.0], accuracy: accuracy)
151153
}
152154
}
155+
#endif // canImport(CoreML)

0 commit comments

Comments
 (0)