
Commit ff45ed3

Update PromptService to new OrtValue API

1 parent: 29e2973


OnnxStack.StableDiffusion/Services/PromptService.cs (26 additions, 16 deletions)
@@ -47,7 +47,7 @@ public async Task<DenseTensor<float>> CreatePromptAsync(IModelOptions model, Pro
     var negativePromptEmbeddings = await GenerateEmbedsAsync(model, negativePromptTokens, maxPromptTokenCount);
 
     // If we have a batch, repeat the prompt embeddings
-    if(promptOptions.BatchCount > 1)
+    if (promptOptions.BatchCount > 1)
     {
         promptEmbeddings = promptEmbeddings.Repeat(promptOptions.BatchCount);
         negativePromptEmbeddings = negativePromptEmbeddings.Repeat(promptOptions.BatchCount);
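Note for readers: Repeat is an OnnxStack tensor extension. As a rough illustration only (the name RepeatBatch, the [1, tokens, embeddings] layout, and the copy loop are assumptions, not the library's actual implementation), a batch-repeat helper could look like this:

using Microsoft.ML.OnnxRuntime.Tensors;

public static class TensorBatchExtensions
{
    // Hypothetical sketch: tile a [1, tokens, embeddings] tensor to
    // [batchCount, tokens, embeddings] by copying the single item N times.
    public static DenseTensor<float> RepeatBatch(this DenseTensor<float> tensor, int batchCount)
    {
        var source = tensor.Buffer.Span;
        var dimensions = tensor.Dimensions.ToArray();
        dimensions[0] *= batchCount;

        var result = new DenseTensor<float>(dimensions);
        for (int i = 0; i < batchCount; i++)
            source.CopyTo(result.Buffer.Span.Slice(i * source.Length));
        return result;
    }
}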
@@ -67,21 +67,24 @@ public async Task<DenseTensor<float>> CreatePromptAsync(IModelOptions model, Pro
 /// </summary>
 /// <param name="inputText">The input text.</param>
 /// <returns>Tokens generated for the specified text input</returns>
-public async Task<int[]> DecodeTextAsync(IModelOptions model, string inputText)
+public Task<int[]> DecodeTextAsync(IModelOptions model, string inputText)
 {
     if (string.IsNullOrEmpty(inputText))
-        return Array.Empty<int>();
+        return Task.FromResult(Array.Empty<int>());
 
-    // Create input tensor.
     var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.Tokenizer);
+    var outputNames = _onnxModelService.GetOutputNames(model, OnnxModelType.Tokenizer);
     var inputTensor = new DenseTensor<string>(new string[] { inputText }, new int[] { 1 });
-    var inputParameters = CreateInputParameters(NamedOnnxValue.CreateFromTensor(inputNames[0], inputTensor));
-
-    // Run inference.
-    using (var inferResult = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.Tokenizer, inputParameters))
+    using (var inputTensorValue = OrtValue.CreateFromStringTensor(inputTensor))
     {
-        var resultTensor = inferResult.FirstElementAs<DenseTensor<long>>();
-        return resultTensor.Select(x => (int)x).ToArray();
+        var outputs = new string[] { outputNames[0] };
+        var inputs = new Dictionary<string, OrtValue> { { inputNames[0], inputTensorValue } };
+        var results = _onnxModelService.RunInference(model, OnnxModelType.Tokenizer, inputs, outputs);
+        using (var result = results.First())
+        {
+            var resultData = result.GetTensorDataAsSpan<long>().ToArray();
+            return Task.FromResult(Array.ConvertAll(resultData, Convert.ToInt32));
+        }
     }
 }
 
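For readers unfamiliar with the new API: OrtValue.CreateFromStringTensor wraps the string input for native consumption, and GetTensorDataAsSpan<long> reads the resulting token ids without an intermediate DenseTensor copy. A minimal standalone sketch of the same pattern against a raw InferenceSession follows (the tokenizer.onnx path and prompt text are assumptions; a CLIP tokenizer model also needs the custom ops from the Microsoft.ML.OnnxRuntime.Extensions package):

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;

// Register the tokenizer's custom ops, then load the model (path assumed).
using var sessionOptions = new SessionOptions();
sessionOptions.RegisterOrtExtensions();
using var session = new InferenceSession("tokenizer.onnx", sessionOptions);
using var runOptions = new RunOptions();

// Wrap the prompt string as a [1] string tensor.
var inputTensor = new DenseTensor<string>(new[] { "a photo of a cat" }, new[] { 1 });
using var inputValue = OrtValue.CreateFromStringTensor(inputTensor);

var inputs = new Dictionary<string, OrtValue>
{
    { session.InputMetadata.Keys.First(), inputValue }
};
using var results = session.Run(runOptions, inputs, session.OutputMetadata.Keys.ToArray());

// Token ids come back as int64; narrow to int32 as PromptService does.
long[] tokenIds = results[0].GetTensorDataAsSpan<long>().ToArray();
int[] tokens = Array.ConvertAll(tokenIds, Convert.ToInt32);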

@@ -95,14 +98,21 @@ public async Task<float[]> EncodeTokensAsync(IModelOptions model, int[] tokenize
 {
     // Create input tensor.
     var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.TextEncoder);
-    var inputTensor = TensorHelper.CreateTensor(tokenizedInput, new[] { 1, tokenizedInput.Length });
-    var inputParameters = CreateInputParameters(NamedOnnxValue.CreateFromTensor(inputNames[0], inputTensor));
+    var outputNames = _onnxModelService.GetOutputNames(model, OnnxModelType.TextEncoder);
 
-    // Run inference.
-    using (var inferResult = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.TextEncoder, inputParameters))
+    var inputDim = new[] { 1L, tokenizedInput.Length };
+    var outputDim = new[] { 1L, tokenizedInput.Length, model.EmbeddingsLength };
+    var outputBuffer = new float[outputDim.GetBufferLength()];
+    using (var inputTensorValue = OrtValue.CreateTensorValueFromMemory(tokenizedInput, inputDim))
+    using (var outputTensorValue = OrtValue.CreateTensorValueFromMemory(outputBuffer, outputDim))
     {
-        var resultTensor = inferResult.FirstElementAs<DenseTensor<float>>();
-        return resultTensor.ToArray();
+        var inputs = new Dictionary<string, OrtValue> { { inputNames[0], inputTensorValue } };
+        var outputs = new Dictionary<string, OrtValue> { { outputNames[0], outputTensorValue } };
+        var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.TextEncoder, inputs, outputs);
+        using (var result = results.First())
+        {
+            return outputBuffer;
+        }
    }
 }
 
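The key change here is that the output is a caller-allocated float[] wrapped by OrtValue.CreateTensorValueFromMemory, so inference writes the embeddings straight into outputBuffer and no DenseTensor copy is needed afterwards. A standalone sketch of that pre-allocated-output pattern (model path, token ids, and the 768 embedding width are assumptions; GetBufferLength above is an OnnxStack helper, approximated here with a product over the dimensions):

using System;
using System.Linq;
using Microsoft.ML.OnnxRuntime;

int[] tokenizedInput = { 49406, 320, 1125, 539, 320, 2368, 49407 }; // example ids
const long EmbeddingsLength = 768; // assumed width (CLIP ViT-L/14)

using var session = new InferenceSession("text_encoder.onnx"); // assumed path
using var runOptions = new RunOptions();

var inputDim = new[] { 1L, tokenizedInput.Length };
var outputDim = new[] { 1L, tokenizedInput.Length, EmbeddingsLength };
var outputBuffer = new float[outputDim.Aggregate(1L, (a, d) => a * d)];

// OrtValue wraps both managed arrays without copying; the session writes
// its result directly into outputBuffer.
using var inputValue = OrtValue.CreateTensorValueFromMemory(tokenizedInput, inputDim);
using var outputValue = OrtValue.CreateTensorValueFromMemory(outputBuffer, outputDim);

session.Run(runOptions,
    new[] { session.InputMetadata.Keys.First() }, new[] { inputValue },
    new[] { session.OutputMetadata.Keys.First() }, new[] { outputValue });

// outputBuffer now holds the [1, tokens, 768] embeddings.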
