-using Microsoft.ML.OnnxRuntime;
-using Microsoft.ML.OnnxRuntime.Tensors;
+using Microsoft.ML.OnnxRuntime.Tensors;
 using OnnxStack.Core;
 using OnnxStack.Core.Config;
 using OnnxStack.Core.Model;
 using OnnxStack.Core.Services;
 using OnnxStack.StableDiffusion.Common;
 using OnnxStack.StableDiffusion.Config;
+using OnnxStack.StableDiffusion.Enums;
 using OnnxStack.StableDiffusion.Helpers;
 using System;
 using System.Collections.Generic;
@@ -40,6 +40,25 @@ public record EmbedsResult(DenseTensor<float> PromptEmbeds, DenseTensor<float> PooledPromptEmbeds);
     /// <param name="negativePrompt">The negative prompt.</param>
     /// <returns>Tensor containing all text embeds generated from the prompt and negative prompt</returns>
     public async Task<PromptEmbeddingsResult> CreatePromptAsync(IModelOptions model, PromptOptions promptOptions, bool isGuidanceEnabled)
+    {
+        return model.TokenizerType switch
+        {
+            TokenizerType.One => await CreateEmbedsOneAsync(model, promptOptions, isGuidanceEnabled),
+            TokenizerType.Two => await CreateEmbedsTwoAsync(model, promptOptions, isGuidanceEnabled),
+            TokenizerType.Both => await CreateEmbedsBothAsync(model, promptOptions, isGuidanceEnabled),
+            _ => throw new ArgumentException("TokenizerType is not set")
+        };
+    }
+
+
+    /// <summary>
+    /// Creates the embeds using Tokenizer and TextEncoder
+    /// </summary>
+    /// <param name="model">The model.</param>
+    /// <param name="promptOptions">The prompt options.</param>
+    /// <param name="isGuidanceEnabled">if set to <c>true</c> is guidance enabled.</param>
+    /// <returns></returns>
+    private async Task<PromptEmbeddingsResult> CreateEmbedsOneAsync(IModelOptions model, PromptOptions promptOptions, bool isGuidanceEnabled)
     {
         // Tokenize Prompt and NegativePrompt
         var promptTokens = await DecodeTextAsIntAsync(model, promptOptions.Prompt);
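
Note: the new dispatch relies on a TokenizerType enum from OnnxStack.StableDiffusion.Enums that the diff imports but never shows. A minimal sketch consistent with the three values used above; the None member is an assumption, inferred from the "TokenizerType is not set" fallback:

    // Hypothetical reconstruction -- the real enum lives in OnnxStack.StableDiffusion.Enums.
    public enum TokenizerType
    {
        None,   // assumed default, matching the ArgumentException fallback
        One,    // Tokenizer + TextEncoder only
        Two,    // Tokenizer2 + TextEncoder2 only
        Both    // all four, with outputs concatenated (e.g. SDXL-style pipelines)
    }
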
@@ -50,31 +69,74 @@ public async Task<PromptEmbeddingsResult> CreatePromptAsync(IModelOptions model, PromptOptions promptOptions, bool isGuidanceEnabled)
         var promptEmbeddings = await GenerateEmbedsAsync(model, promptTokens, maxPromptTokenCount);
         var negativePromptEmbeddings = await GenerateEmbedsAsync(model, negativePromptTokens, maxPromptTokenCount);
 
-        if (model.IsDualTokenizer)
-        {
-            /// Tokenize Prompt and NegativePrompt with Tokenizer2
-            var dualPromptTokens = await DecodeTextAsLongAsync(model, promptOptions.Prompt);
-            var dualNegativePromptTokens = await DecodeTextAsLongAsync(model, promptOptions.NegativePrompt);
+        if (isGuidanceEnabled)
+            return new PromptEmbeddingsResult(negativePromptEmbeddings.Concatenate(promptEmbeddings));
 
-            // Generate embeds for tokens
-            var dualPromptEmbeddings = await GenerateEmbedsAsync(model, dualPromptTokens, maxPromptTokenCount);
-            var dualNegativePromptEmbeddings = await GenerateEmbedsAsync(model, dualNegativePromptTokens, maxPromptTokenCount);
+        return new PromptEmbeddingsResult(promptEmbeddings);
+    }
 
-            var dualPrompt = promptEmbeddings.Concatenate(dualPromptEmbeddings.PromptEmbeds, 2);
-            var dualNegativePrompt = negativePromptEmbeddings.Concatenate(dualNegativePromptEmbeddings.PromptEmbeds, 2);
-            var pooledPromptEmbeds = dualPromptEmbeddings.PooledPromptEmbeds;
-            var pooledNegativePromptEmbeds = dualNegativePromptEmbeddings.PooledPromptEmbeds;
+    /// <summary>
+    /// Creates the embeds using Tokenizer2 and TextEncoder2
+    /// </summary>
+    /// <param name="model">The model.</param>
+    /// <param name="promptOptions">The prompt options.</param>
+    /// <param name="isGuidanceEnabled">if set to <c>true</c> is guidance enabled.</param>
+    /// <returns></returns>
+    private async Task<PromptEmbeddingsResult> CreateEmbedsTwoAsync(IModelOptions model, PromptOptions promptOptions, bool isGuidanceEnabled)
+    {
+        /// Tokenize Prompt and NegativePrompt with Tokenizer2
+        var promptTokens = await DecodeTextAsLongAsync(model, promptOptions.Prompt);
+        var negativePromptTokens = await DecodeTextAsLongAsync(model, promptOptions.NegativePrompt);
+        var maxPromptTokenCount = Math.Max(promptTokens.Length, negativePromptTokens.Length);
 
-            if (isGuidanceEnabled)
-                return new PromptEmbeddingsResult(dualNegativePrompt.Concatenate(dualPrompt), pooledNegativePromptEmbeds.Concatenate(pooledPromptEmbeds));
+        // Generate embeds for tokens
+        var promptEmbeddings = await GenerateEmbedsAsync(model, promptTokens, maxPromptTokenCount);
+        var negativePromptEmbeddings = await GenerateEmbedsAsync(model, negativePromptTokens, maxPromptTokenCount);
 
-            return new PromptEmbeddingsResult(dualPrompt, pooledPromptEmbeds);
-        }
+        if (isGuidanceEnabled)
+            return new PromptEmbeddingsResult(
+                negativePromptEmbeddings.PromptEmbeds.Concatenate(promptEmbeddings.PromptEmbeds),
+                negativePromptEmbeddings.PooledPromptEmbeds.Concatenate(promptEmbeddings.PooledPromptEmbeds));
+
+        return new PromptEmbeddingsResult(promptEmbeddings.PromptEmbeds, promptEmbeddings.PooledPromptEmbeds);
+    }
+
+
+    /// <summary>
+    /// Creates the embeds using Tokenizer, Tokenizer2, TextEncoder and TextEncoder2
+    /// </summary>
+    /// <param name="model">The model.</param>
+    /// <param name="promptOptions">The prompt options.</param>
+    /// <param name="isGuidanceEnabled">if set to <c>true</c> is guidance enabled.</param>
+    /// <returns></returns>
+    private async Task<PromptEmbeddingsResult> CreateEmbedsBothAsync(IModelOptions model, PromptOptions promptOptions, bool isGuidanceEnabled)
+    {
+        // Tokenize Prompt and NegativePrompt
+        var promptTokens = await DecodeTextAsIntAsync(model, promptOptions.Prompt);
+        var negativePromptTokens = await DecodeTextAsIntAsync(model, promptOptions.NegativePrompt);
+        var maxPromptTokenCount = Math.Max(promptTokens.Length, negativePromptTokens.Length);
+
+        // Generate embeds for tokens
+        var promptEmbeddings = await GenerateEmbedsAsync(model, promptTokens, maxPromptTokenCount);
+        var negativePromptEmbeddings = await GenerateEmbedsAsync(model, negativePromptTokens, maxPromptTokenCount);
+
+        /// Tokenize Prompt and NegativePrompt with Tokenizer2
+        var dualPromptTokens = await DecodeTextAsLongAsync(model, promptOptions.Prompt);
+        var dualNegativePromptTokens = await DecodeTextAsLongAsync(model, promptOptions.NegativePrompt);
+
+        // Generate embeds for tokens
+        var dualPromptEmbeddings = await GenerateEmbedsAsync(model, dualPromptTokens, maxPromptTokenCount);
+        var dualNegativePromptEmbeddings = await GenerateEmbedsAsync(model, dualNegativePromptTokens, maxPromptTokenCount);
+
+        var dualPrompt = promptEmbeddings.Concatenate(dualPromptEmbeddings.PromptEmbeds, 2);
+        var dualNegativePrompt = negativePromptEmbeddings.Concatenate(dualNegativePromptEmbeddings.PromptEmbeds, 2);
+        var pooledPromptEmbeds = dualPromptEmbeddings.PooledPromptEmbeds;
+        var pooledNegativePromptEmbeds = dualNegativePromptEmbeddings.PooledPromptEmbeds;
 
         if (isGuidanceEnabled)
-            return new PromptEmbeddingsResult(negativePromptEmbeddings.Concatenate(promptEmbeddings));
+            return new PromptEmbeddingsResult(dualNegativePrompt.Concatenate(dualPrompt), pooledNegativePromptEmbeds.Concatenate(pooledPromptEmbeds));
 
-        return new PromptEmbeddingsResult(promptEmbeddings);
+        return new PromptEmbeddingsResult(dualPrompt, pooledPromptEmbeds);
     }
 
 
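
Note: in the Both path, Concatenate(..., 2) joins the two encoders' outputs along the hidden dimension, not the token dimension. A shape walk-through under assumed SDXL-typical widths (768 for TextEncoder, 1280 for TextEncoder2; both values are assumptions, not shown in this diff):

    // Assumed shapes for a 77-token prompt:
    // promptEmbeddings      (TextEncoder)   [1, 77, 768]
    // dualPromptEmbeddings  (TextEncoder2)  [1, 77, 1280], pooled [1, 1280]
    // promptEmbeddings.Concatenate(dualPromptEmbeddings.PromptEmbeds, 2) -> [1, 77, 2048]
    // With guidance enabled, negative and positive then stack on axis 0  -> [2, 77, 2048]
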
@@ -138,7 +200,7 @@ private Task<long[]> DecodeTextAsLongAsync(IModelOptions model, string inputText)
     private async Task<float[]> EncodeTokensAsync(IModelOptions model, int[] tokenizedInput)
     {
         var inputDim = new[] { 1, tokenizedInput.Length };
-        var outputDim = new[] { 1, tokenizedInput.Length, model.EmbeddingsLength };
+        var outputDim = new[] { 1, tokenizedInput.Length, model.TokenizerLength };
         var metadata = _onnxModelService.GetModelMetadata(model, OnnxModelType.TextEncoder);
         var inputTensor = new DenseTensor<int>(tokenizedInput, inputDim);
         using (var inferenceParameters = new OnnxInferenceParameters(metadata))
@@ -164,8 +226,8 @@ private async Task<float[]> EncodeTokensAsync(IModelOptions model, int[] tokenizedInput)
     private async Task<EncoderResult> EncodeTokensAsync(IModelOptions model, long[] tokenizedInput)
     {
         var inputDim = new[] { 1, tokenizedInput.Length };
-        var promptOutputDim = new[] { 1, tokenizedInput.Length, model.DualEmbeddingsLength };
-        var pooledOutputDim = new[] { 1, model.DualEmbeddingsLength };
+        var promptOutputDim = new[] { 1, tokenizedInput.Length, model.Tokenizer2Length };
+        var pooledOutputDim = new[] { 1, model.Tokenizer2Length };
         var metadata = _onnxModelService.GetModelMetadata(model, OnnxModelType.TextEncoder2);
         var inputTensor = new DenseTensor<long>(tokenizedInput, inputDim);
         using (var inferenceParameters = new OnnxInferenceParameters(metadata))
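
Note: the renames in these two hunks (EmbeddingsLength to TokenizerLength, DualEmbeddingsLength to Tokenizer2Length) imply matching changes on IModelOptions. A hypothetical sketch of just the members this commit touches; the real interface lives in OnnxStack.StableDiffusion and certainly carries more:

    // Sketch only -- member names come from this diff, everything else is assumed.
    public interface IModelOptions
    {
        TokenizerType TokenizerType { get; }  // replaces the removed IsDualTokenizer flag
        int TokenizerLength { get; }          // TextEncoder hidden size (was EmbeddingsLength)
        int Tokenizer2Length { get; }         // TextEncoder2 hidden size (was DualEmbeddingsLength)
    }
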
@@ -206,12 +268,12 @@ private async Task<EmbedsResult> GenerateEmbedsAsync(IModelOptions model, long[]
             pooledEmbeds.AddRange(result.PooledPromptEmbeds);
         }
 
-        var embeddingsDim = new[] { 1, embeddings.Count / model.DualEmbeddingsLength, model.DualEmbeddingsLength };
+        var embeddingsDim = new[] { 1, embeddings.Count / model.Tokenizer2Length, model.Tokenizer2Length };
         var promptTensor = TensorHelper.CreateTensor(embeddings.ToArray(), embeddingsDim);
 
         //TODO: Pooled embeds do not support more than 77 tokens, just grab first set
-        var pooledDim = new[] { 1, model.DualEmbeddingsLength };
-        var pooledTensor = TensorHelper.CreateTensor(pooledEmbeds.Take(model.DualEmbeddingsLength).ToArray(), pooledDim);
+        var pooledDim = new[] { 1, model.Tokenizer2Length };
+        var pooledTensor = TensorHelper.CreateTensor(pooledEmbeds.Take(model.Tokenizer2Length).ToArray(), pooledDim);
         return new EmbedsResult(promptTensor, pooledTensor);
     }
 
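
Note: per the TODO above, only the first tokenizer window's pooled vector survives chunked encoding. A worked example, assuming a 77-token window and a Tokenizer2Length of 1280 (both assumed values):

    // 154 prompt tokens -> two encoder passes of 77 tokens each
    // embeddings.Count == 2 * 77 * 1280 -> embeddingsDim == { 1, 154, 1280 }
    // pooledEmbeds holds 2 * 1280 floats, but Take(1280) keeps only the first chunk's vector
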
@@ -236,7 +298,7 @@ private async Task<DenseTensor<float>> GenerateEmbedsAsync(IModelOptions model,
             embeddings.AddRange(await EncodeTokensAsync(model, tokens.ToArray()));
         }
 
-        var dim = new[] { 1, embeddings.Count / model.EmbeddingsLength, model.EmbeddingsLength };
+        var dim = new[] { 1, embeddings.Count / model.TokenizerLength, model.TokenizerLength };
         return TensorHelper.CreateTensor(embeddings.ToArray(), dim);
     }
 
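
Note: for context, a minimal hypothetical call site for the refactored entry point. The service field and options objects are assumptions, and the PromptEmbeddingsResult property names are inferred from the constructors used above:

    // Sketch only -- _promptService, model, and promptOptions are assumed to exist.
    var result = await _promptService.CreatePromptAsync(model, promptOptions, isGuidanceEnabled: true);
    DenseTensor<float> promptEmbeds = result.PromptEmbeds;        // [2, tokens, hidden] when guidance is on
    DenseTensor<float> pooledEmbeds = result.PooledPromptEmbeds;  // only set for TokenizerType.Two/Both (assumed)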