@@ -140,7 +140,7 @@ private async Task<PromptEmbeddingsResult> CreateEmbedsTwoAsync(PromptOptions pr
     // Tokenize Prompt and NegativePrompt with Tokenizer2
     var promptTokens = await DecodeTextAsLongAsync(promptOptions.Prompt);
     var negativePromptTokens = await DecodeTextAsLongAsync(promptOptions.NegativePrompt);
-    var maxPromptTokenCount = Math.Max(promptTokens.Length, negativePromptTokens.Length);
+    var maxPromptTokenCount = Math.Max(promptTokens.InputIds.Length, negativePromptTokens.InputIds.Length);

     // Generate embeds for tokens
     var promptEmbeddings = await GenerateEmbedsAsync(promptTokens, maxPromptTokenCount);
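Note: the diff never shows the `TokenizerResult` type itself. A minimal sketch consistent with how it is used below (a `(long[], long[])` constructor plus settable arrays, since `GenerateEmbedsAsync` reassigns both after padding) might look like this; the real definition in the repository may differ:

```csharp
// Hypothetical sketch of TokenizerResult; the actual definition is not part of this diff.
public class TokenizerResult
{
    public long[] InputIds { get; set; }        // token ids produced by the tokenizer model
    public long[] AttentionMask { get; set; }   // 1 per real token, aligned with InputIds

    public TokenizerResult(long[] inputIds, long[] attentionMask)
    {
        InputIds = inputIds;
        AttentionMask = attentionMask;
    }
}
```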
@@ -174,7 +174,7 @@ private async Task<PromptEmbeddingsResult> CreateEmbedsBothAsync(PromptOptions p
     // Tokenize Prompt and NegativePrompt
     var promptTokens = await DecodePromptTextAsync(promptOptions.Prompt);
     var negativePromptTokens = await DecodePromptTextAsync(promptOptions.NegativePrompt);
-    var maxPromptTokenCount = Math.Max(promptTokens.Length, negativePromptTokens.Length);
+    var maxPromptTokenCount = Math.Max(promptTokens.InputIds.Length, negativePromptTokens.InputIds.Length);

     // Generate embeds for tokens
     var promptEmbeddings = await GeneratePromptEmbedsAsync(promptTokens, maxPromptTokenCount);
@@ -188,8 +188,8 @@ private async Task<PromptEmbeddingsResult> CreateEmbedsBothAsync(PromptOptions p
     var dualPromptEmbeddings = await GenerateEmbedsAsync(dualPromptTokens, maxPromptTokenCount);
     var dualNegativePromptEmbeddings = await GenerateEmbedsAsync(dualNegativePromptTokens, maxPromptTokenCount);

-    var dualPrompt = promptEmbeddings.Concatenate(dualPromptEmbeddings.PromptEmbeds, 2);
-    var dualNegativePrompt = negativePromptEmbeddings.Concatenate(dualNegativePromptEmbeddings.PromptEmbeds, 2);
+    var dualPrompt = promptEmbeddings.PromptEmbeds.Concatenate(dualPromptEmbeddings.PromptEmbeds, 2);
+    var dualNegativePrompt = negativePromptEmbeddings.PromptEmbeds.Concatenate(dualNegativePromptEmbeddings.PromptEmbeds, 2);
     var pooledPromptEmbeds = dualPromptEmbeddings.PooledPromptEmbeds;
     var pooledNegativePromptEmbeds = dualNegativePromptEmbeddings.PooledPromptEmbeds;

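The fix above is worth spelling out: `Concatenate(..., 2)` joins the two encoders' hidden states on axis 2 (the embedding channels), so the call has to start from `promptEmbeddings.PromptEmbeds` (a tensor), not from the `PromptEmbeddingsResult` wrapper. For an SDXL-style pair of encoders this typically turns `[1, tokens, 768]` and `[1, tokens, 1280]` into `[1, tokens, 2048]`. A rough sketch of that last-axis concatenation for row-major buffers (illustrative only; the real code uses OnnxStack's tensor `Concatenate` extension):

```csharp
// Concatenate two row-major [1, tokens, dimA] / [1, tokens, dimB] buffers
// along the last axis, yielding [1, tokens, dimA + dimB].
static float[] ConcatLastAxis(float[] a, float[] b, int tokens, int dimA, int dimB)
{
    var result = new float[tokens * (dimA + dimB)];
    for (int t = 0; t < tokens; t++)
    {
        Array.Copy(a, t * dimA, result, t * (dimA + dimB), dimA);
        Array.Copy(b, t * dimB, result, t * (dimA + dimB) + dimA, dimB);
    }
    return result;
}
```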
@@ -212,22 +212,21 @@ private async Task<PromptEmbeddingsResult> CreateEmbedsBothAsync(PromptOptions p
 /// </summary>
 /// <param name="inputText">The input text.</param>
 /// <returns></returns>
-private async Task<long[]> DecodeTextAsLongAsync(string inputText)
+private async Task<TokenizerResult> DecodeTextAsLongAsync(string inputText)
 {
     if (string.IsNullOrEmpty(inputText))
-        return Array.Empty<long>();
+        return new TokenizerResult(Array.Empty<long>(), Array.Empty<long>());

     var metadata = await _tokenizer2.GetMetadataAsync();
     var inputTensor = new DenseTensor<string>(new string[] { inputText }, new int[] { 1 });
     using (var inferenceParameters = new OnnxInferenceParameters(metadata))
     {
         inferenceParameters.AddInputTensor(inputTensor);
         inferenceParameters.AddOutputBuffer();
-
-        using (var results = _tokenizer2.RunInference(inferenceParameters))
+        inferenceParameters.AddOutputBuffer();
+        using (var results = _tokenizer2.RunInference(inferenceParameters))
         {
-            var resultData = results.First().ToArray<long>();
-            return resultData;
+            return new TokenizerResult(results[0].ToArray<long>(), results[1].ToArray<long>());
         }
     }
 }
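With the second `AddOutputBuffer()` registered, the tokenizer model is expected to emit two outputs, which `results[0]` and `results[1]` map to input ids and attention mask. A hedged usage sketch (the prompt text and variable names are illustrative, not from the diff):

```csharp
// Hypothetical usage; output lengths depend on the prompt.
var tokens = await DecodeTextAsLongAsync("a photo of an astronaut riding a horse");
Console.WriteLine($"ids: {tokens.InputIds.Length}, mask: {tokens.AttentionMask.Length}");
// Both arrays come back the same length; the mask holds 1 for each real token.
```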
@@ -238,16 +237,16 @@ private async Task<long[]> DecodeTextAsLongAsync(string inputText)
 /// </summary>
 /// <param name="tokenizedInput">The tokenized input.</param>
 /// <returns></returns>
-private async Task<EncoderResult> EncodeTokensAsync(long[] tokenizedInput)
+private async Task<EncoderResult> EncodeTokensAsync(TokenizerResult tokenizedInput)
 {
     var metadata = await _textEncoder2.GetMetadataAsync();
-    var inputTensor = new DenseTensor<long>(tokenizedInput, new[] { 1, tokenizedInput.Length });
+    var inputTensor = new DenseTensor<long>(tokenizedInput.InputIds, new[] { 1, tokenizedInput.InputIds.Length });
     using (var inferenceParameters = new OnnxInferenceParameters(metadata))
     {
         int hiddenStateIndex = metadata.Outputs.Count - 2;
         inferenceParameters.AddInputTensor(inputTensor);
         inferenceParameters.AddOutputBuffer(new[] { 1, _tokenizer2.TokenizerLength });
-        inferenceParameters.AddOutputBuffer(hiddenStateIndex, new[] { 1, tokenizedInput.Length, _tokenizer2.TokenizerLength });
+        inferenceParameters.AddOutputBuffer(hiddenStateIndex, new[] { 1, tokenizedInput.InputIds.Length, _tokenizer2.TokenizerLength });

         var results = await _textEncoder2.RunInferenceAsync(inferenceParameters);
         var promptEmbeds = results.Last().ToDenseTensor();
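The `hiddenStateIndex = metadata.Outputs.Count - 2` line assumes a fixed output ordering on the exported text encoder, with the tensor one slot before the last being the hidden state this pipeline wants; SDXL-style pipelines typically read a late, not final, CLIP hidden layer. If an ONNX export orders its outputs differently, the index needs adjusting. A worked example of the arithmetic, under that layout assumption:

```csharp
// Assumption: CLIP text encoder exports often list a pooled output plus one
// tensor per hidden layer. With 13 outputs total, Outputs.Count - 2 == 11,
// i.e. the penultimate hidden state that SDXL-style pipelines usually consume.
int outputCount = 13;                   // illustrative count, not from the diff
int hiddenStateIndex = outputCount - 2; // == 11
Console.WriteLine(hiddenStateIndex);
```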
@@ -263,30 +262,43 @@ private async Task<EncoderResult> EncodeTokensAsync(long[] tokenizedInput)
 /// <param name="inputTokens">The input tokens.</param>
 /// <param name="minimumLength">The minimum length.</param>
 /// <returns></returns>
-private async Task<PromptEmbeddingsResult> GenerateEmbedsAsync(long[] inputTokens, int minimumLength)
+private async Task<PromptEmbeddingsResult> GenerateEmbedsAsync(TokenizerResult inputTokens, int minimumLength)
 {
     // If less than minimumLength pad with blank tokens
-    if (inputTokens.Length < minimumLength)
-        inputTokens = PadWithBlankTokens(inputTokens, minimumLength).ToArray();
+    if (inputTokens.InputIds.Length < minimumLength)
+    {
+        inputTokens.InputIds = PadWithBlankTokens(inputTokens.InputIds, minimumLength, _tokenizer.PadTokenId).ToArray();
+        inputTokens.AttentionMask = PadWithBlankTokens(inputTokens.AttentionMask, minimumLength, 1).ToArray();
+    }

     // The CLIP tokenizer only supports 77 tokens, batch process in groups of 77 and concatenate
-    var embeddings = new List<float>();
-    var pooledEmbeds = new List<float>();
-    foreach (var tokenBatch in inputTokens.Batch(_tokenizer2.TokenizerLimit))
-    {
-        var tokens = PadWithBlankTokens(tokenBatch, _tokenizer2.TokenizerLimit);
-        var result = await EncodeTokensAsync(tokens.ToArray());
+    var tokenBatches = new List<long[]>();
+    var attentionBatches = new List<long[]>();
+    foreach (var tokenBatch in inputTokens.InputIds.Batch(_tokenizer.TokenizerLimit))
+        tokenBatches.Add(PadWithBlankTokens(tokenBatch, _tokenizer.TokenizerLimit, _tokenizer.PadTokenId).ToArray());
+    foreach (var attentionBatch in inputTokens.AttentionMask.Batch(_tokenizer.TokenizerLimit))
+        attentionBatches.Add(PadWithBlankTokens(attentionBatch, _tokenizer.TokenizerLimit, 1).ToArray());
+

-        embeddings.AddRange(result.PromptEmbeds);
-        pooledEmbeds.AddRange(result.PooledPromptEmbeds);
+    var promptEmbeddings = new List<float>();
+    var pooledPromptEmbeddings = new List<float>();
+    for (int i = 0; i < tokenBatches.Count; i++)
+    {
+        var result = await EncodeTokensAsync(new TokenizerResult(tokenBatches[i], attentionBatches[i]));
+        promptEmbeddings.AddRange(result.PromptEmbeds);
+        pooledPromptEmbeddings.AddRange(result.PooledPromptEmbeds);
     }
+
+
+    //var embeddingsDim = new[] { 1, promptEmbeddings.Count / _tokenizer2.TokenizerLength, _tokenizer2.TokenizerLength };
+    //var promptTensor = new DenseTensor<float>(promptEmbeddings.ToArray(), embeddingsDim);

-    var embeddingsDim = new[] { 1, embeddings.Count / _tokenizer2.TokenizerLength, _tokenizer2.TokenizerLength };
-    var promptTensor = new DenseTensor<float>(embeddings.ToArray(), embeddingsDim);
+    ////TODO: Pooled embeds do not support more than 77 tokens, just grab first set
+    //var pooledDim = new[] { 1, _tokenizer2.TokenizerLength };
+    //var pooledTensor = new DenseTensor<float>(pooledPromptEmbeddings.Take(_tokenizer2.TokenizerLength).ToArray(), pooledDim);

-    //TODO: Pooled embeds do not support more than 77 tokens, just grab first set
-    var pooledDim = new[] { 1, _tokenizer2.TokenizerLength };
-    var pooledTensor = new DenseTensor<float>(pooledEmbeds.Take(_tokenizer2.TokenizerLength).ToArray(), pooledDim);
+    var promptTensor = new DenseTensor<float>(promptEmbeddings.ToArray(), new[] { 1, promptEmbeddings.Count / _tokenizer2.TokenizerLength, _tokenizer2.TokenizerLength });
+    var pooledTensor = new DenseTensor<float>(pooledPromptEmbeddings.ToArray(), new[] { 1, pooledPromptEmbeddings.Count });
     return new PromptEmbeddingsResult(promptTensor, pooledTensor);
 }

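Two things are worth noting in the rewritten batching: `Batch(_tokenizer.TokenizerLimit)` splits long prompts into 77-token windows that are encoded independently and concatenated, and the pooled tensor is now shaped `[1, pooledPromptEmbeddings.Count]`, keeping every window's pooled vector rather than only the first (the old TODO). The `Batch` extension is OnnxStack's own and is not shown here; a rough equivalent under the assumed semantics (consecutive chunks of at most `size` items, last chunk possibly shorter and padded later):

```csharp
// Approximation of the Batch extension used above; assumed behavior only.
public static class EnumerableBatchExtensions
{
    public static IEnumerable<long[]> Batch(this IEnumerable<long> source, int size)
    {
        var bucket = new List<long>(size);
        foreach (var item in source)
        {
            bucket.Add(item);
            if (bucket.Count == size)
            {
                yield return bucket.ToArray();
                bucket.Clear();
            }
        }
        if (bucket.Count > 0)
            yield return bucket.ToArray(); // trailing partial chunk
    }
}
```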
@@ -297,11 +309,11 @@ private async Task<PromptEmbeddingsResult> GenerateEmbedsAsync(long[] inputToken
 /// <param name="inputs">The inputs.</param>
 /// <param name="requiredLength">The required length.</param>
 /// <returns></returns>
-private IEnumerable<long> PadWithBlankTokens(IEnumerable<long> inputs, int requiredLength)
+private IEnumerable<long> PadWithBlankTokens(IEnumerable<long> inputs, int requiredLength, int padTokenId)
 {
     var count = inputs.Count();
     if (requiredLength > count)
-        return inputs.Concat(Enumerable.Repeat((long)_tokenizer.PadTokenId, requiredLength - count));
+        return inputs.Concat(Enumerable.Repeat((long)padTokenId, requiredLength - count));
     return inputs;
 }

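With the pad value now a parameter, ids and masks can be padded differently: ids with the tokenizer's `PadTokenId`, masks with `1`, which means the encoder still attends to the pad positions (consistent with how Stable Diffusion pipelines commonly run CLIP without masking). A small illustrative call; the token ids are made up, though CLIP vocabularies often use 49406/49407 for the start/end markers:

```csharp
// Illustrative only: pad a 7-token prompt out to the 77-token CLIP window.
var ids = new long[] { 49406, 320, 1125, 539, 550, 18376, 49407 };
var paddedIds = PadWithBlankTokens(ids, 77, 49407).ToArray();            // length 77
var paddedMask = PadWithBlankTokens(Enumerable.Repeat(1L, 7), 77, 1)     // all ones
    .ToArray();
```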