Skip to content

Commit 52a6f59

Browse files
authored
Add documentation comments for public APIs (#268)
1 parent fb9fce2 commit 52a6f59

27 files changed

+1292
-83
lines changed

Sources/Generation/CoreML+Extensions.swift

Lines changed: 65 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,15 @@ import CoreML
1111
import Foundation
1212

1313
extension MLMultiArray {
14-
/// All values will be stored in the last dimension of the MLMultiArray (default is dims=1)
14+
/// Creates an MLMultiArray from an array of integers.
15+
///
16+
/// All values are stored in the last dimension of the MLMultiArray, with leading
17+
/// dimensions set to 1. For example, with dims=2, the shape becomes [1, arr.count].
18+
///
19+
/// - Parameters:
20+
/// - arr: Array of integers to convert
21+
/// - dims: Number of dimensions for the resulting MLMultiArray
22+
/// - Returns: MLMultiArray containing the integer values
1523
static func from(_ arr: [Int], dims: Int = 1) -> MLMultiArray {
1624
var shape = Array(repeating: 1, count: dims)
1725
shape[shape.count - 1] = arr.count
@@ -27,7 +35,15 @@ extension MLMultiArray {
2735
return o
2836
}
2937

30-
/// All values will be stored in the last dimension of the MLMultiArray (default is dims=1)
38+
/// Creates an MLMultiArray from an array of doubles.
39+
///
40+
/// All values are stored in the last dimension of the MLMultiArray, with leading
41+
/// dimensions set to 1. For example, with dims=2, the shape becomes [1, arr.count].
42+
///
43+
/// - Parameters:
44+
/// - arr: Array of doubles to convert
45+
/// - dims: Number of dimensions for the resulting MLMultiArray
46+
/// - Returns: MLMultiArray containing the double values
3147
static func from(_ arr: [Double], dims: Int = 1) -> MLMultiArray {
3248
var shape = Array(repeating: 1, count: dims)
3349
shape[shape.count - 1] = arr.count
@@ -43,7 +59,13 @@ extension MLMultiArray {
4359
return o
4460
}
4561

46-
/// This will concatenate all dimensions into one one-dim array.
62+
/// Converts an MLMultiArray to a flat array of integers.
63+
///
64+
/// Concatenates all dimensions into a single one-dimensional array by reading
65+
/// the MLMultiArray data in memory order.
66+
///
67+
/// - Parameter o: MLMultiArray to convert; its underlying storage is read as `Int32` values
68+
/// - Returns: Flat array of integer values
4769
static func toIntArray(_ o: MLMultiArray) -> [Int] {
4870
var arr = Array(repeating: 0, count: o.count)
4971
let ptr = UnsafeMutablePointer<Int32>(OpaquePointer(o.dataPointer))
@@ -53,9 +75,18 @@ extension MLMultiArray {
5375
return arr
5476
}
5577

78+
/// Converts this MLMultiArray to a flat array of integers.
79+
///
80+
/// - Returns: Flat array of integer values
5681
func toIntArray() -> [Int] { Self.toIntArray(self) }
5782

58-
/// This will concatenate all dimensions into one one-dim array.
83+
/// Converts an MLMultiArray to a flat array of doubles.
84+
///
85+
/// Concatenates all dimensions into a single one-dimensional array by reading
86+
/// the MLMultiArray data in memory order.
87+
///
88+
/// - Parameter o: MLMultiArray to convert; its underlying storage is read as `Double` values
89+
/// - Returns: Flat array of double values
5990
static func toDoubleArray(_ o: MLMultiArray) -> [Double] {
6091
var arr: [Double] = Array(repeating: 0, count: o.count)
6192
let ptr = UnsafeMutablePointer<Double>(OpaquePointer(o.dataPointer))
@@ -65,11 +96,17 @@ extension MLMultiArray {
6596
return arr
6697
}
6798

99+
/// Converts this MLMultiArray to a flat array of doubles.
100+
///
101+
/// - Returns: Flat array of double values
68102
func toDoubleArray() -> [Double] { Self.toDoubleArray(self) }
69103

70-
/// Helper to construct a sequentially-indexed multi array,
71-
/// useful for debugging and unit tests
72-
/// Example in 3 dimensions:
104+
/// Creates a test MLMultiArray with sequentially indexed values.
105+
///
106+
/// Useful for debugging and unit tests. Values are assigned sequentially
107+
/// starting from 0, following the memory layout of the specified shape.
108+
///
109+
/// Example output for shape [2, 3, 4]:
73110
/// ```
74111
/// [[[ 0, 1, 2, 3 ],
75112
/// [ 4, 5, 6, 7 ],
@@ -78,6 +115,9 @@ extension MLMultiArray {
78115
/// [ 16, 17, 18, 19 ],
79116
/// [ 20, 21, 22, 23 ]]]
80117
/// ```
118+
///
119+
/// - Parameter shape: Desired shape of the test tensor
120+
/// - Returns: MLMultiArray with sequential values for testing
81121
static func testTensor(shape: [Int]) -> MLMultiArray {
82122
let arr = try! MLMultiArray(shape: shape as [NSNumber], dataType: .double)
83123
let ptr = UnsafeMutablePointer<Double>(OpaquePointer(arr.dataPointer))
@@ -199,6 +239,12 @@ extension MLMultiArray {
199239
}
200240

201241
extension MLShapedArray<Float> {
242+
/// Efficiently extracts float values from the shaped array.
243+
///
244+
/// Uses optimized memory copying when the array is one-dimensional with unit stride,
245+
/// falling back to slower scalar access in every other case.
246+
///
247+
/// - Returns: Array of Float values from the shaped array
202248
var floats: [Float] {
203249
guard strides.first == 1, strides.count == 1 else {
204250
// For some reason this path is slow.
@@ -213,6 +259,12 @@ extension MLShapedArray<Float> {
213259
}
214260

215261
extension MLShapedArraySlice<Float> {
262+
/// Efficiently extracts float values from the shaped array slice.
263+
///
264+
/// Uses optimized memory copying when the slice is one-dimensional with unit stride,
265+
/// falling back to slower scalar access in every other case.
266+
///
267+
/// - Returns: Array of Float values from the shaped array slice
216268
var floats: [Float] {
217269
guard strides.first == 1, strides.count == 1 else {
218270
// For some reason this path is slow.
@@ -227,6 +279,12 @@ extension MLShapedArraySlice<Float> {
227279
}
228280

229281
extension MLMultiArray {
282+
/// Efficiently extracts float values from the MLMultiArray if it contains float32 data.
283+
///
284+
/// Uses fast memory copying to extract all float values as a contiguous array.
285+
/// Returns nil if the array doesn't contain float32 data.
286+
///
287+
/// - Returns: Array of Float values, or nil if not float32 type
230288
var floats: [Float]? {
231289
guard dataType == .float32 else { return nil }
232290

Sources/Generation/Generation.swift

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,28 +10,62 @@ import CoreML
1010

1111
import Tokenizers
1212

13+
/// Supported text generation modes.
1314
public enum GenerationMode {
15+
/// Contrastive search generation mode
1416
case contrastiveSearch
17+
/// Greedy decoding generation mode
1518
case greedy
19+
/// Sampling-based generation mode
1620
case sample
21+
/// Beam search generation mode
1722
case beam
23+
/// Group beam search generation mode
1824
case groupBeam
25+
/// Unsupported generation mode
1926
case unsupported
2027
}
2128

29+
/// Array of token IDs representing input tokens.
2230
public typealias InputTokens = [Int]
31+
32+
/// Array of token IDs representing generated output tokens.
2333
public typealias GenerationOutput = [Int]
2434

25-
/// A callable (a model, usually), that predicts the next token after a given sequence
35+
/// A callable model that predicts the next token after a given sequence.
36+
///
37+
/// - Parameter tokens: Input token sequence
38+
/// - Parameter config: Generation configuration
39+
/// - Returns: Logits array for next token prediction
2640
public typealias NextTokenModel = (InputTokens, GenerationConfig) -> any MLShapedArrayProtocol
2741

42+
/// Callback for receiving generated tokens during streaming.
2843
public typealias PredictionTokensCallback = (GenerationOutput) -> Void
44+
45+
/// Callback for receiving generated text during streaming.
2946
public typealias PredictionStringCallback = (String) -> Void
3047

31-
// TODO: callbacks (for streaming)
48+
/// Protocol for text generation implementations.
3249
public protocol Generation {
50+
/// Performs greedy search generation.
51+
///
52+
/// - Parameters:
53+
/// - config: Generation configuration
54+
/// - tokens: Input token sequence
55+
/// - model: Model for next token prediction
56+
/// - callback: Optional callback for streaming tokens
57+
/// - Returns: Generated token sequence
3358
func greedySearch(config: GenerationConfig, tokens: InputTokens, model: NextTokenModel, callback: PredictionTokensCallback?) async -> GenerationOutput
3459

60+
/// Generates text from a prompt string.
61+
///
62+
/// - Parameters:
63+
/// - config: Generation configuration
64+
/// - prompt: Input prompt text
65+
/// - model: Model for next token prediction
66+
/// - tokenizer: Tokenizer for encoding/decoding
67+
/// - callback: Optional callback for streaming text
68+
/// - Returns: Generated text string
3569
func generate(config: GenerationConfig, prompt: String, model: NextTokenModel, tokenizer: Tokenizer, callback: PredictionStringCallback?) async -> String
3670
}
3771

@@ -50,7 +84,19 @@ public extension Generation {
5084
return outputTokens
5185
}
5286

53-
/// https://github.com/huggingface/transformers/blob/42017d82baa083da2bee3055fdac80c81ee97b8a/src/transformers/generation/utils.py#L1552
87+
/// Performs sampling-based text generation with configurable logits warping.
88+
///
89+
/// Uses various logits warpers (temperature, top-k, top-p, repetition penalty) to modify
90+
/// token probabilities before sampling, enabling diverse and controllable text generation.
91+
///
92+
/// - Parameters:
93+
/// - config: Generation configuration with sampling parameters
94+
/// - tokens: Input token sequence
95+
/// - model: Model for next token prediction
96+
/// - callback: Optional callback for streaming tokens
97+
/// - Returns: Generated token sequence
98+
///
99+
/// - Note: Based on https://github.com/huggingface/transformers/blob/42017d82baa083da2bee3055fdac80c81ee97b8a/src/transformers/generation/utils.py#L1552
54100
func sample(config: GenerationConfig, tokens: InputTokens, model: NextTokenModel, callback: PredictionTokensCallback? = nil) async -> GenerationOutput {
55101
// Iterate until we find the eos token or reach the max length
56102
// TODO: additional stopping criteria

Sources/Generation/GenerationConfig.swift

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,65 @@
77

88
import Foundation
99

10-
/// Essentials taken from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/configuration_utils.py
10+
/// Configuration parameters for text generation algorithms.
11+
///
12+
/// Contains all the parameters needed to control various aspects of text generation,
13+
/// including sampling parameters, beam search settings, and special token IDs.
14+
///
15+
/// - Note: Based on https://github.com/huggingface/transformers/blob/main/src/transformers/generation/configuration_utils.py
1116
public struct GenerationConfig {
17+
/// Maximum total length of the generated sequence (input + output tokens).
1218
public var maxLength = 20
19+
20+
/// Maximum number of new tokens to generate.
1321
public var maxNewTokens: Int
22+
23+
/// Whether to use sampling instead of deterministic decoding.
1424
public var doSample = false
25+
26+
/// Number of beams for beam search (1 means no beam search).
1527
public var numBeams = 1
28+
29+
/// Number of beam groups for group beam search.
1630
public var numBeamGroups = 1
31+
32+
/// Penalty parameter for contrastive search.
1733
public var penaltyAlpha: Double?
34+
35+
/// Temperature for sampling (higher values increase randomness).
1836
public var temperature = 1.0
37+
38+
/// Number of top tokens to consider for top-k sampling.
1939
public var topK = 50
40+
41+
/// Cumulative probability threshold for top-p sampling.
2042
public var topP = 1.0
43+
44+
/// Penalty for token repetition (1.0 means no penalty).
2145
public var repetitionPenalty = 1.0
2246

47+
/// Token ID used for padding sequences.
2348
public var padTokenId: Int?
49+
50+
/// Token ID for beginning of sequence.
2451
public var bosTokenId: Int?
52+
53+
/// Token ID for end of sequence.
2554
public var eosTokenId: Int?
2655

56+
/// Creates a new generation configuration.
57+
///
58+
/// - Parameters:
59+
/// - maxLength: Maximum total sequence length
60+
/// - maxNewTokens: Maximum new tokens to generate
61+
/// - doSample: Enable sampling instead of greedy decoding
62+
/// - numBeams: Number of beams for beam search
63+
/// - numBeamGroups: Number of beam groups for group beam search
64+
/// - penaltyAlpha: Penalty parameter for contrastive search
65+
/// - temperature: Sampling temperature
66+
/// - topK: Top-k sampling parameter
67+
/// - topP: Top-p sampling parameter
68+
/// - repetitionPenalty: Repetition penalty factor
2769
public init(maxLength: Int = 20, maxNewTokens: Int, doSample: Bool = false, numBeams: Int = 1, numBeamGroups: Int = 1, penaltyAlpha: Double? = nil, temperature: Double = 1.0, topK: Int = 50, topP: Double = 1.0, repetitionPenalty: Double = 1.0) {
2870
self.maxLength = maxLength
2971
self.maxNewTokens = maxNewTokens
@@ -39,6 +81,12 @@ public struct GenerationConfig {
3981
}
4082

4183
public extension GenerationConfig {
84+
/// Determines the appropriate generation mode based on configuration parameters.
85+
///
86+
/// Analyzes the combination of sampling settings, beam parameters, and penalty values
87+
/// to automatically select the most appropriate generation algorithm.
88+
///
89+
/// - Returns: The determined generation mode
4290
var generationMode: GenerationMode {
4391
// Exclude this case from the pattern matching below
4492
if topK > 1, !doSample, penaltyAlpha != nil, penaltyAlpha! > 0 {

Sources/Generation/LogitsWarper/LogitsProcessor.swift

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,27 @@
11
import Foundation
22

3+
/// Processes logits by applying a sequence of logits warpers.
4+
///
5+
/// Coordinates the application of multiple logits warpers in sequence,
6+
/// allowing for complex probability transformations during text generation.
37
public struct LogitsProcessor {
8+
/// Array of logits warpers to apply in sequence.
49
public var logitsWarpers: [any LogitsWarper]
510

11+
/// Creates a new logits processor.
12+
///
13+
/// - Parameter logitsWarpers: Array of warpers to apply in sequence
614
public init(logitsWarpers: [any LogitsWarper]) {
715
self.logitsWarpers = logitsWarpers
816
}
917

18+
/// Processes logits by applying all warpers in sequence.
19+
///
20+
/// Each warper is applied to the output of the previous warper, allowing
21+
/// for complex chaining of probability transformations.
22+
///
23+
/// - Parameter arr: Input logits array
24+
/// - Returns: Tuple of processed (indices, logits)
1025
public func callAsFunction(_ arr: [Float]) -> (indices: [Int], logits: [Float]) {
1126
var indices = Array(arr.indices)
1227
var logits = arr

Sources/Generation/LogitsWarper/LogitsWarper.swift

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,34 @@
11
import Foundation
22

3-
/// Protocol for all logit warpers that can be applied during generation
3+
/// Protocol for logits warpers that transform token probabilities during generation.
4+
///
5+
/// Logits warpers modify the probability distribution over tokens before sampling,
6+
/// enabling techniques like temperature scaling, top-k/top-p filtering, and repetition penalties.
47
public protocol LogitsWarper {
8+
/// Warps the logits and corresponding indices.
9+
///
10+
/// - Parameters:
11+
/// - indices: Array of token indices corresponding to the logits
12+
/// - logits: Array of logit values to transform
13+
/// - Returns: Tuple of transformed (indices, logits)
514
func warp(indices: [Int], logits: [Float]) -> (indices: [Int], logits: [Float])
15+
16+
/// Convenience method that calls the warp function.
17+
///
18+
/// - Parameters:
19+
/// - indices: Array of token indices
20+
/// - logits: Array of logit values
21+
/// - Returns: Tuple of transformed (indices, logits)
622
func callAsFunction(_ indices: [Int], _ logits: [Float]) -> (indices: [Int], logits: [Float])
723
}
824

925
public extension LogitsWarper {
26+
/// Default implementation of callAsFunction that delegates to warp.
27+
///
28+
/// - Parameters:
29+
/// - indices: Array of token indices
30+
/// - logits: Array of logit values
31+
/// - Returns: Tuple of transformed (indices, logits)
1032
func callAsFunction(_ indices: [Int], _ logits: [Float]) -> (indices: [Int], logits: [Float]) {
1133
warp(indices: indices, logits: logits)
1234
}

0 commit comments

Comments
 (0)