Skip to content

Commit 3f1a93f

Browse files
authored
Reduce access level for extensions on imported types (#250)
* Change extension on Array<String> providing matching(glob:) to private access level
* Change extension on FileManager providing getFileUrls(at:) to private access level
* Change extension on String providing trimmingFromStart/End to private access level
* Change extension on String for splitting to internal access level
* Consolidate string splitting tests into PreTokenizerTests
* Rename Hub/BOMDoubling.swift to Hub/Extensions/JSONSerialization+BOM.swift
* Change extension on Character providing isExtendedPunctuation to private access level
* Extract String extensions to separate file
* Fixup string+Extensions
* Change extension on CoreML types to internal access level
  Extract and consolidate CoreML extensions into separate file
* @testable import Tokenizers to access internal extensions
* Rename String+Extensions.swift to String+PreTokenization.swift
1 parent c5de635 commit 3f1a93f

File tree

12 files changed

+282
-336
lines changed

12 files changed

+282
-336
lines changed

Sources/Generation/MLMultiArray+Utils.swift renamed to Sources/Generation/CoreML+Extensions.swift

Lines changed: 47 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
//
2-
// MLMultiArray+Utils.swift
2+
// CoreML+Extensions.swift
33
// CoreMLBert
44
//
55
// Created by Julien Chaumond on 27/06/2019.
@@ -10,7 +10,7 @@
1010
import CoreML
1111
import Foundation
1212

13-
public extension MLMultiArray {
13+
extension MLMultiArray {
1414
/// All values will be stored in the last dimension of the MLMultiArray (default is dims=1)
1515
static func from(_ arr: [Int], dims: Int = 1) -> MLMultiArray {
1616
var shape = Array(repeating: 1, count: dims)
@@ -88,7 +88,7 @@ public extension MLMultiArray {
8888
}
8989
}
9090

91-
public extension MLMultiArray {
91+
extension MLMultiArray {
9292
/// Provides a way to index n-dimensional arrays à la NumPy.
9393
enum Indexing: Equatable {
9494
case select(Int)
@@ -197,4 +197,48 @@ extension MLMultiArray {
197197
return s + "]"
198198
}
199199
}
200+
201+
extension MLShapedArray<Float> {
202+
var floats: [Float] {
203+
guard strides.first == 1, strides.count == 1 else {
204+
// For some reason this path is slow.
205+
// If strides is not 1, we can write a Metal kernel to copy the values properly.
206+
return scalars
207+
}
208+
209+
// Fast path: memcpy
210+
let mlArray = MLMultiArray(self)
211+
return mlArray.floats ?? scalars
212+
}
213+
}
214+
215+
extension MLShapedArraySlice<Float> {
216+
var floats: [Float] {
217+
guard strides.first == 1, strides.count == 1 else {
218+
// For some reason this path is slow.
219+
// If strides is not 1, we can write a Metal kernel to copy the values properly.
220+
return scalars
221+
}
222+
223+
// Fast path: memcpy
224+
let mlArray = MLMultiArray(self)
225+
return mlArray.floats ?? scalars
226+
}
227+
}
228+
229+
extension MLMultiArray {
230+
var floats: [Float]? {
231+
guard dataType == .float32 else { return nil }
232+
233+
var result: [Float] = Array(repeating: 0, count: count)
234+
return withUnsafeBytes { ptr in
235+
guard let source = ptr.baseAddress else { return nil }
236+
result.withUnsafeMutableBytes { resultPtr in
237+
let dest = resultPtr.baseAddress!
238+
memcpy(dest, source, self.count * MemoryLayout<Float>.stride)
239+
}
240+
return result
241+
}
242+
}
243+
}
200244
#endif // canImport(CoreML)

Sources/Generation/MLShapedArray+Utils.swift

Lines changed: 0 additions & 54 deletions
This file was deleted.

Sources/Hub/BOMDoubling.swift renamed to Sources/Hub/Extensions/JSONSerialization+BOM.swift

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,19 @@
11
//
2-
// BOMDoubling.swift
2+
// JSONSerialization+BOM.swift
33
// swift-transformers
44
//
55
// Created by Pedro Cuenca on 20250912
66
//
77

88
import Foundation
99

10-
extension Data {
10+
extension JSONSerialization {
11+
class func bomPreservingJsonObject(with data: Data, options: JSONSerialization.ReadingOptions = []) throws -> Any {
12+
try JSONSerialization.jsonObject(with: data.duplicatingBOMsAfterQuotes, options: options)
13+
}
14+
}
15+
16+
private extension Data {
1117
/// Workaround for https://github.com/huggingface/swift-transformers/issues/116
1218
/// Duplicate a BOM sequence that follows a quote. The first BOM is swallowed by JSONSerialization.jsonObject
1319
/// because it thinks it marks the encoding.
@@ -40,9 +46,3 @@ extension Data {
4046
}
4147
}
4248
}
43-
44-
extension JSONSerialization {
45-
class func bomPreservingJsonObject(with data: Data, options: JSONSerialization.ReadingOptions = []) throws -> Any {
46-
try JSONSerialization.jsonObject(with: data.duplicatingBOMsAfterQuotes, options: options)
47-
}
48-
}

Sources/Hub/HubApi.swift

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -891,13 +891,13 @@ public extension Hub {
891891
}
892892
}
893893

894-
public extension [String] {
894+
private extension [String] {
895895
func matching(glob: String) -> [String] {
896896
filter { fnmatch(glob, $0, 0) == 0 }
897897
}
898898
}
899899

900-
public extension FileManager {
900+
private extension FileManager {
901901
func getFileUrls(at directoryUrl: URL) throws -> [URL] {
902902
var fileUrls = [URL]()
903903

Sources/Tokenizers/BertTokenizer.swift

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -226,7 +226,7 @@ class BasicTokenizer {
226226
}
227227
}
228228

229-
extension Character {
229+
private extension Character {
230230
/// https://github.com/huggingface/transformers/blob/8c1b5d37827a6691fef4b2d926f2d04fb6f5a9e3/src/transformers/tokenization_utils.py#L367
231231
var isExtendedPunctuation: Bool {
232232
if isPunctuation { return true }

Sources/Tokenizers/Decoder.swift

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -236,7 +236,7 @@ class MetaspaceDecoder: Decoder {
236236
}
237237

238238
/// We could use firstIndex(where:), lastIndex(where:) for possibly better efficiency (and do both ends at once)
239-
public extension String {
239+
private extension String {
240240
func trimmingFromStart(character: Character = " ", upto: Int) -> String {
241241
var result = self
242242
var trimmed = 0

Sources/Tokenizers/PreTokenizer.swift

Lines changed: 0 additions & 161 deletions
Original file line number | Diff line number | Diff line change
@@ -238,164 +238,3 @@ class SplitPreTokenizer: PreTokenizer {
238238
return pattern.split(text, invert: invert)
239239
}
240240
}
241-
242-
enum StringSplitPattern {
243-
case regexp(regexp: String)
244-
case string(pattern: String)
245-
}
246-
247-
extension StringSplitPattern {
248-
func split(_ text: String, invert: Bool = true) -> [String] {
249-
switch self {
250-
case let .regexp(regexp):
251-
text.split(by: regexp, includeSeparators: true)
252-
case let .string(substring):
253-
text.split(by: substring, options: [], includeSeparators: !invert)
254-
}
255-
}
256-
}
257-
258-
extension StringSplitPattern {
259-
static func from(config: Config) -> StringSplitPattern? {
260-
if let pattern = config.pattern.String.string() {
261-
return StringSplitPattern.string(pattern: pattern)
262-
}
263-
if let pattern = config.pattern.Regex.string() {
264-
return StringSplitPattern.regexp(regexp: pattern)
265-
}
266-
return nil
267-
}
268-
}
269-
270-
public extension String {
271-
func ranges(of string: String, options: CompareOptions = .regularExpression) -> [Range<Index>] {
272-
var result: [Range<Index>] = []
273-
var start = startIndex
274-
while let range = range(of: string, options: options, range: start..<endIndex) {
275-
result.append(range)
276-
start = range.lowerBound < range.upperBound ? range.upperBound : index(range.lowerBound, offsetBy: 1, limitedBy: endIndex) ?? endIndex
277-
}
278-
return result
279-
}
280-
281-
func split(by string: String, options: CompareOptions = .regularExpression, includeSeparators: Bool = false, omittingEmptySubsequences: Bool = true) -> [String] {
282-
var result: [String] = []
283-
var start = startIndex
284-
while let range = range(of: string, options: options, range: start..<endIndex) {
285-
// Prevent empty strings
286-
if omittingEmptySubsequences, start < range.lowerBound {
287-
result.append(String(self[start..<range.lowerBound]))
288-
}
289-
if includeSeparators {
290-
result.append(String(self[range]))
291-
}
292-
start = range.upperBound
293-
}
294-
295-
if omittingEmptySubsequences, start < endIndex {
296-
result.append(String(self[start...]))
297-
}
298-
return result
299-
}
300-
301-
/// This version supports capture groups, whereas the one above doesn't
302-
func split(by captureRegex: NSRegularExpression) -> [String] {
303-
// Find the matching capture groups
304-
let selfRange = NSRange(startIndex..<endIndex, in: self)
305-
let matches = captureRegex.matches(in: self, options: [], range: selfRange)
306-
307-
if matches.isEmpty { return [self] }
308-
309-
var result: [String] = []
310-
var start = startIndex
311-
312-
for match in matches {
313-
// IMPORTANT: convert from NSRange to Range<String.Index>
314-
// https://stackoverflow.com/questions/75543272/convert-a-given-utf8-nsrange-in-a-string-to-a-utf16-nsrange
315-
guard let matchRange = Range(match.range, in: self) else { continue }
316-
317-
// Add text before the match
318-
if start < matchRange.lowerBound {
319-
result.append(String(self[start..<matchRange.lowerBound]))
320-
}
321-
322-
// Move start to after the match
323-
start = matchRange.upperBound
324-
325-
// Append separator, supporting capture groups
326-
for r in (0..<match.numberOfRanges).reversed() {
327-
let nsRange = match.range(at: r)
328-
if let sepRange = Range(nsRange, in: self) {
329-
result.append(String(self[sepRange]))
330-
break
331-
}
332-
}
333-
}
334-
335-
// Append remaining suffix
336-
if start < endIndex {
337-
result.append(String(self[start...]))
338-
}
339-
340-
return result
341-
}
342-
}
343-
344-
public enum SplitDelimiterBehavior {
345-
case removed
346-
case isolated
347-
case mergedWithPrevious
348-
case mergedWithNext
349-
}
350-
351-
public extension String {
352-
func split(by string: String, options: CompareOptions = .regularExpression, behavior: SplitDelimiterBehavior) -> [String] {
353-
func mergedWithNext(ranges: [Range<String.Index>]) -> [Range<String.Index>] {
354-
var merged: [Range<String.Index>] = []
355-
var currentStart = startIndex
356-
for range in ranges {
357-
if range.lowerBound == startIndex { continue }
358-
let mergedRange = currentStart..<range.lowerBound
359-
currentStart = range.lowerBound
360-
merged.append(mergedRange)
361-
}
362-
if currentStart < endIndex {
363-
merged.append(currentStart..<endIndex)
364-
}
365-
return merged
366-
}
367-
368-
func mergedWithPrevious(ranges: [Range<String.Index>]) -> [Range<String.Index>] {
369-
var merged: [Range<String.Index>] = []
370-
var currentStart = startIndex
371-
for range in ranges {
372-
let mergedRange = currentStart..<range.upperBound
373-
currentStart = range.upperBound
374-
merged.append(mergedRange)
375-
}
376-
if currentStart < endIndex {
377-
merged.append(currentStart..<endIndex)
378-
}
379-
return merged
380-
}
381-
382-
switch behavior {
383-
case .removed:
384-
return split(by: string, options: options, includeSeparators: false)
385-
case .isolated:
386-
return split(by: string, options: options, includeSeparators: true)
387-
case .mergedWithNext:
388-
// Obtain ranges and merge them
389-
// "the-final--countdown" -> (3, 4), (9, 10), (10, 11) -> (start, 2), (3, 8), (9, 9), (10, end)
390-
let ranges = ranges(of: string, options: options)
391-
let merged = mergedWithNext(ranges: ranges)
392-
return merged.map { String(self[$0]) }
393-
case .mergedWithPrevious:
394-
// Obtain ranges and merge them
395-
// "the-final--countdown" -> (3, 4), (9, 10), (10, 11) -> (start, 3), (4, 9), (10, 10), (11, end)
396-
let ranges = ranges(of: string, options: options)
397-
let merged = mergedWithPrevious(ranges: ranges)
398-
return merged.map { String(self[$0]) }
399-
}
400-
}
401-
}

0 commit comments

Comments
 (0)