@@ -4,139 +4,151 @@ import CoreML
 import Accelerate
 
 
-public protocol Embedding { }
+class BERTEmbedding {
 
-public struct AutoEmbedding { } // Otherwise AutoModel
-
-extension AutoEmbedding {
-    public static func from(pretrained model: String, hubApi: HubApi = .shared) async throws -> Embedding {
-        return try await BGEM3Model(repoName: model, hubApi: hubApi)
-    }
-}
-
-class BERTEmbedding: Embedding { // Otherwise BERTModel
-    private let wordEmbedding: BNNS.EmbeddingLayer
-    private let positionEmbedding: BNNS.EmbeddingLayer
-    private let tokenTypeEmbedding: BNNS.EmbeddingLayer
-    private let normalization: BNNS.NormalizationLayer
-    private let dropout: BNNS.DropoutLayer
-
-    private let positionEmbeddingType = "absolute"
-
-    init(repoName: String) { fatalError() }
-
-    public func callAsFunction(inputIds: MLMultiArray? = nil,
-                               tokenTypeIDs: MLMultiArray? = nil,
-                               positionIDs: MLMultiArray? = nil,
-                               inputEmbeds: MLMultiArray? = nil,
-                               pastKeyValuesLength: Int = 0) -> MLMultiArray {
-        fatalError()
-    }
-}
-
-class BGEM3Model: Embedding {
-
-    struct Output {
-        let lastHidddenState: MLMultiArray // batchSize, sequenceLength, hiddenSize
-        let hiddenStates: MLMultiArray?
-        let attentions: MLMultiArray?
-
-        let loss: MLMultiArray?
-        let scores: MLMultiArray?
-        let pReps: MLMultiArray?
-        let qReps: MLMultiArray?
-    }
-
-    let withSparse = false
-    let withDense = true
-    let withColbert = false
-
-    let shouldNormalize = false
-    // let poolingMethod = "cls"
-    // let negativesCrossDevice = false
-    // let temperature = 1.0
-    // let enableSubBatch = true
-    // let unifiedFinetuning = true
-    // let useSelfDistill = false
-    // let colbertDim: Int? = nil
-    // let selfDistillStartStep: Int? = nil
-
-    private let tokenizer: Tokenizer
-    private let denseLayer: BNNS.FullyConnectedLayer
-    private let sparseLayer: BNNS.FullyConnectedLayer
-    private let colbertLayer: BNNS.FullyConnectedLayer
-
-    init(repoName: String, hubApi: HubApi) async throws {
-        let config = LanguageModelConfigurationFromHub(modelName: repoName)
-        self.tokenizer = try await AutoTokenizer.from(pretrained: repoName, hubApi: hubApi)
-
-        let hiddenSize = try await config.modelConfig.hiddenSize?.intValue ?? 384
-        let colbertDim: Int? = nil
-        let denseInput = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(hiddenSize, stride: 2))
-        let denseOutput = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(colbertDim ?? hiddenSize, stride: 2))
-        let denseWeights = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(hiddenSize, stride: 2))
-        self.denseLayer = BNNS.FullyConnectedLayer(input: denseInput, output: denseOutput, weights: denseWeights, bias: nil, activation: .identity)!
+    typealias Weights = [String: MLMultiArray]
+
+    var shape: [NSNumber] { [
+        NSNumber(value: maxPositionEmbeddings),
+        NSNumber(value: hiddenSize),
+    ] }
+
+    private let weights: Weights
+
+    private let positionEmbeddingType: String
+    private let hiddenSize: Int
+    private let vocabSize: Int
+    private let maxPositionEmbeddings: Int
+    private let typeVocabSize: Int
+    private let padTokenID: Int
+    private let normalizationEpsilon: Float
+    private let dropoutRate: Float = 1e-1
+    private let hiddenActivation: BNNS.ActivationFunction = .geluApproximation2(alpha: 1e-1, beta: 1e-1)
+
+    private var allocations: [BNNSNDArrayDescriptor] = []
+
+    private lazy var wordEmbedding: BNNS.EmbeddingLayer = {
+        let input = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Int64.self, shape: .vector(maxPositionEmbeddings))
+        allocations.append(input)
+        let dictData: [Float32] = weights["bert.embeddings.word_embeddings.weight"]!.toArray()
+        let dict = BNNSNDArrayDescriptor.allocate(initializingFrom: dictData, shape: .matrixColumnMajor(hiddenSize, vocabSize))
+        allocations.append(dict)
+        let output = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings))
+        allocations.append(output)
 
-        let sparseInput = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(hiddenSize, stride: 2))
-        let sparseOutput = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(1, stride: 2))
-        let sparseWeights = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(hiddenSize, stride: 2))
-        self.sparseLayer = BNNS.FullyConnectedLayer(input: sparseInput, output: sparseOutput, weights: sparseWeights, bias: nil, activation: .identity)!
+        return BNNS.EmbeddingLayer(input: input, output: output, dictionary: dict, paddingIndex: 0, maximumNorm: 0, normType: .l2, scalesGradientByFrequency: false)!
+    }()
+
+    private lazy var positionEmbedding: BNNS.EmbeddingLayer = {
+        let input = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Int64.self, shape: .vector(maxPositionEmbeddings))
+        allocations.append(input)
+        let dictData: [Float32] = weights["bert.embeddings.position_embeddings.weight"]!.toArray()
+        let dict = BNNSNDArrayDescriptor.allocate(initializingFrom: dictData, shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings))
+        allocations.append(dict)
+        let output = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings))
+        allocations.append(output)
+
+        return BNNS.EmbeddingLayer(input: input, output: output, dictionary: dict, paddingIndex: -1, maximumNorm: 0, normType: .l2, scalesGradientByFrequency: true)!
+    }()
+
+    private lazy var tokenTypeEmbedding: BNNS.EmbeddingLayer = {
+        let input = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Int64.self, shape: .vector(maxPositionEmbeddings))
+        allocations.append(input)
+        let dictData: [Float32] = weights["bert.embeddings.token_type_embeddings.weight"]!.toArray()
+        let dict = BNNSNDArrayDescriptor.allocate(initializingFrom: dictData, shape: .matrixColumnMajor(hiddenSize, typeVocabSize))
+        allocations.append(dict)
+        let output = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings))
+        allocations.append(output)
 
-        let colbertInput = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(hiddenSize, stride: 2))
-        let colbertOutput = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(1, stride: 2))
-        let colbertWeights = BNNSNDArrayDescriptor(dataType: .float16, shape: .vector(hiddenSize, stride: 2))
-        self.colbertLayer = BNNS.FullyConnectedLayer(input: colbertInput, output: colbertOutput, weights: colbertWeights, bias: nil, activation: .identity)!
-    }
-
-    public func callAsFunction(_ textInput: (indices: MLMultiArray, attentionMask: MLMultiArray)) -> Output {
-        fatalError()
+        return BNNS.EmbeddingLayer(input: input, output: output, dictionary: dict, paddingIndex: -1, maximumNorm: 0, normType: .l2, scalesGradientByFrequency: true)!
+    }()
+
+    private lazy var normalization: BNNS.NormalizationLayer = {
+        let input = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixRowMajor(maxPositionEmbeddings, hiddenSize))
+        allocations.append(input)
+        let output = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixRowMajor(maxPositionEmbeddings, hiddenSize))
+        allocations.append(output)
+
+        let betaWA: MLMultiArray! = weights["bert.embeddings.LayerNorm.beta"] ?? weights["bert.embeddings.LayerNorm.bias"]
+        let beta = BNNSNDArrayDescriptor.allocate(initializingFrom: betaWA.toArray() as [Float32], shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings))
+        allocations.append(beta)
+
+        let gammaWA: MLMultiArray! = weights["bert.embeddings.LayerNorm.gamma"] ?? weights["bert.embeddings.LayerNorm.weight"]
+        let gamma = BNNSNDArrayDescriptor.allocate(initializingFrom: gammaWA.toArray() as [Float32], shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings))
+        allocations.append(gamma)
+
+        return BNNS.NormalizationLayer(type: .batch(movingMean: nil, movingVariance: nil), input: input, output: output, beta: beta, gamma: gamma, epsilon: normalizationEpsilon, activation: hiddenActivation)!
+    }()
+
+    private lazy var dropout: BNNS.DropoutLayer = {
+        let input = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings))
+        allocations.append(input)
+        let output = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings))
+        allocations.append(output)
+
+        return BNNS.DropoutLayer(input: input, output: output, rate: dropoutRate, seed: 0, control: 0)!
+    }()
+
+    deinit {
+        allocations.forEach({ $0.deallocate() })
     }
 
-    private func forward(textInput: (indices: MLMultiArray, attentionMask: MLMultiArray)) -> [String: MLMultiArray] {
-        let lastHiddenState = self(textInput).lastHidddenState
-
-        var output = [String: MLMultiArray]()
-        if withDense {
-            output["dense"] = self.dense(hiddenState: lastHiddenState, mask: textInput.attentionMask)
-        }
-        if withSparse {
-            output["sparse"] = self.sparse(hiddenState: lastHiddenState, mask: textInput.attentionMask)
+    init(config: Config, weights: Weights = [:]) {
+        assert(config.model_type!.stringValue == "bert")
+        for key in [
+            "bert.embeddings.word_embeddings.weight",
+            "bert.embeddings.position_embeddings.weight",
+            "bert.embeddings.token_type_embeddings.weight",
+        ] { assert(weights.keys.contains(where: { $0 == key })) }
+        assert(weights.keys.contains(where: { $0 == "bert.embeddings.LayerNorm.beta" || $0 == "bert.embeddings.LayerNorm.bias" }))
+        assert(weights.keys.contains(where: { $0 == "bert.embeddings.LayerNorm.gamma" || $0 == "bert.embeddings.LayerNorm.weight" }))
+        assert(config.hidden_act!.stringValue == "gelu")
+        assert("absolute" == config.position_embedding_type!.stringValue!)
+        self.positionEmbeddingType = config.position_embedding_type!.stringValue!
+        self.hiddenSize = config.hidden_size!.intValue!
+        self.vocabSize = config.vocab_size!.intValue!
+        self.maxPositionEmbeddings = config.max_position_embeddings!.intValue!
+        self.typeVocabSize = config.type_vocab_size!.intValue!
+        self.padTokenID = config.pad_token_id!.intValue!
+        self.normalizationEpsilon = Float(config.layer_norm_eps!.doubleValue!)
+        self.weights = weights
+    }
+
+    public func callAsFunction(inputIDs: [Int64],
+                               tokenTypeIDs: [Int64]? = nil,
+                               positionIDs: [Int64]? = nil) -> MLMultiArray {
+        let inputLength = inputIDs.count
+        let inputIDs: [Int64] = inputIDs.padded(length: maxPositionEmbeddings)
+        let wordInput = BNNSNDArrayDescriptor.allocate(initializingFrom: inputIDs, shape: .vector(inputIDs.count))
+        let wordOutput = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixColumnMajor(hiddenSize, inputIDs.count))
+        defer {
+            wordInput.deallocate()
+            wordOutput.deallocate()
         }
-        if withColbert {
-            output["colbert"] = self.colbert(hiddenState: lastHiddenState, mask: textInput.attentionMask)
+        try! wordEmbedding.apply(batchSize: 1, input: wordInput, output: wordOutput)
+
+        let positionIDs = positionIDs ?? Array<Int64>(stride(from: 0, through: Int64(inputLength - 1), by: 1))
+        let positionInput = BNNSNDArrayDescriptor.allocate(initializingFrom: positionIDs.padded(length: maxPositionEmbeddings), shape: .vector(maxPositionEmbeddings))
+        let positionOutput = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings))
+        defer {
+            positionInput.deallocate()
+            positionOutput.deallocate()
        }
-
-        if shouldNormalize {
-            if withDense {
-                // TODO: Normalize output["dense"] =
-                fatalError()
-            }
-            if withColbert {
-                // TODO: Normalize output["colbert"] =
-                fatalError()
-            }
+        try! self.positionEmbedding.apply(batchSize: 1, input: positionInput, output: positionOutput)
+
+        let tokenTypeIDs: [Int64] = tokenTypeIDs ?? Array(repeating: 0, count: maxPositionEmbeddings)
+        let typeInput = BNNSNDArrayDescriptor.allocate(initializingFrom: tokenTypeIDs, shape: .vector(maxPositionEmbeddings))
+        let typeOutput = BNNSNDArrayDescriptor.allocateUninitialized(scalarType: Float32.self, shape: .matrixColumnMajor(hiddenSize, maxPositionEmbeddings))
+        defer {
+            typeInput.deallocate()
+            typeOutput.deallocate()
        }
+        try! self.tokenTypeEmbedding.apply(batchSize: 1, input: typeInput, output: typeOutput)
 
-        return output
-    }
-
-    private func dense(hiddenState: MLMultiArray, mask: MLMultiArray) -> MLMultiArray {
-        assert(hiddenState.shape.count == 2)
-        var data = [Float]()
-        data.reserveCapacity(hiddenState.count)
-
-        for index in 0..<hiddenState.count {
-            data.append(hiddenState[index].floatValue)
-        }
-
-        return try! MLMultiArray(data)
-    }
-
-    private func sparse(hiddenState: MLMultiArray, mask: MLMultiArray) -> MLMultiArray {
-        fatalError()
-    }
+        let multiWord = try! wordOutput.makeMultiArray(of: Float32.self, shape: shape)
+        let multiPosition = try! positionOutput.makeMultiArray(of: Float32.self, shape: shape)
+        let multiType = try! typeOutput.makeMultiArray(of: Float32.self, shape: shape)
 
-    private func colbert(hiddenState: MLMultiArray, mask: MLMultiArray) -> MLMultiArray {
-        fatalError()
+        return multiWord + multiPosition + multiType
     }
 }
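For orientation, here is a minimal sketch of how this embedding block might be driven. It assumes the package's dynamic-member `Config` type parsed from `config.json`, a weight dictionary of `MLMultiArray`s keyed by the original BERT tensor names (loaded elsewhere), and token IDs already produced by a tokenizer; none of this driver code is part of the commit itself.

```swift
import CoreML

// Hypothetical caller: `Config` and the weight-loading step are assumptions,
// only BERTEmbedding comes from this change.
func embed(tokenIDs: [Int64], config: Config, weights: [String: MLMultiArray]) -> MLMultiArray {
    let embedding = BERTEmbedding(config: config, weights: weights)
    // Element-wise sum of word, position, and token-type embeddings with
    // shape [maxPositionEmbeddings, hiddenSize]; note that the LayerNorm and
    // dropout layers are constructed by the class but not applied in this
    // forward pass.
    return embedding(inputIDs: tokenIDs)
}
```

The returned sum matches BERT's embedding definition up to, but not including, the LayerNorm and dropout steps.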
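The forward pass also relies on an `Array.padded(length:)` helper that this hunk does not define. A plausible shape for it, assuming right-padding with zeros (consistent with `paddingIndex: 0` on the word-embedding layer), is:

```swift
// Hypothetical helper — the real extension lives elsewhere in the package and
// may differ (e.g. it could pad with padTokenID or handle truncation differently).
extension Array where Element == Int64 {
    func padded(length: Int) -> [Int64] {
        guard count < length else { return Array(prefix(length)) }
        return self + [Int64](repeating: 0, count: length - count)
    }
}
```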