@@ -5,24 +5,51 @@ import CoreML
55
@available(macOS 15.0, iOS 18.0, *)
func selectNextTokenUsingGreedyDecoding(from scores: MLTensor) -> MLTensor {
    // Greedy decoding: take the highest-scoring token along the vocabulary
    // (last) axis and shape it as a single [1, 1] token tensor.
    let nextToken = scores.argmax(alongAxis: -1).reshaped(to: [1, 1])
    // Downstream concatenation with the input token tensor expects Int32 indices.
    if nextToken.scalarType == Int32.self {
        return nextToken
    }
    return nextToken.cast(to: Int32.self)
}
1012
11- // MARK: Top-K Sampling
13+ // MARK: Sampling
1214
/// Performs multinomial sampling from processed logits.
///
/// Assumes logits have already been processed by LogitsProcessorList
/// (temperature, top-k, top-p, etc. already applied).
///
/// - Parameter scores: Processed logits tensor [batch_size, vocab_size]
/// - Returns: Sampled token ID tensor [batch_size, 1]
@available(macOS 15.0, iOS 18.0, *)
func selectNextTokenUsingSampling(from scores: MLTensor) -> MLTensor {
    // Convert logits to probabilities along the vocabulary axis.
    let probs = scores.softmax(alongAxis: -1)

    // Multinomial sampling via the inverse-CDF (cumulative sum) method:
    //   1. Draw a uniform random number r in [0, 1) per batch row.
    //   2. Compute the cumulative sum of the probabilities.
    //   3. Select the first index where cumsum >= r.
    // This is equivalent to torch.multinomial() but uses available MLTensor ops.
    let batchSize = scores.shape[0]
    let rndTensor = MLTensor(randomUniform: [batchSize, 1], in: 0..<1, scalarType: Float.self)
    let cumulativeProbs = probs.cumulativeSum(alongAxis: -1)

    // Match the random tensor's scalar type to the probabilities
    // (e.g. models producing Float16 logits).
    let rnd = cumulativeProbs.scalarType == Float.self
        ? rndTensor
        : rndTensor.cast(to: cumulativeProbs.scalarType)

    // We want the FIRST position where cumsum >= rnd.
    // Add a large penalty (1000.0) to every position where cumsum < rnd;
    // the remaining positions keep their cumulative values. Because cumsum
    // is non-decreasing, the first qualifying index holds the minimum value,
    // so argmin selects it.
    let mask = cumulativeProbs .< rnd
    let penalized = mask * 1000.0 // Large value for positions to skip
    let indexed = penalized + cumulativeProbs // Positions >= rnd keep small values

    // Fix: reshape to [batchSize, 1] — the documented return shape — rather
    // than a hard-coded [1, 1], which is wrong for batch sizes other than 1.
    let sampledIndex = indexed.argmin(alongAxis: -1).reshaped(to: [batchSize, 1])
    // Ensure indices are Int32 for concatenation with input tokens.
    return sampledIndex.scalarType == Int32.self ? sampledIndex : sampledIndex.cast(to: Int32.self)
}
2855#endif // canImport(CoreML)
0 commit comments