@@ -47,7 +47,7 @@ public async Task<DenseTensor<float>> CreatePromptAsync(IModelOptions model, Pro
4747 var negativePromptEmbeddings = await GenerateEmbedsAsync ( model , negativePromptTokens , maxPromptTokenCount ) ;
4848
4949 // If we have a batch, repeat the prompt embeddings
50- if ( promptOptions . BatchCount > 1 )
50+ if ( promptOptions . BatchCount > 1 )
5151 {
5252 promptEmbeddings = promptEmbeddings . Repeat ( promptOptions . BatchCount ) ;
5353 negativePromptEmbeddings = negativePromptEmbeddings . Repeat ( promptOptions . BatchCount ) ;
@@ -67,21 +67,24 @@ public async Task<DenseTensor<float>> CreatePromptAsync(IModelOptions model, Pro
6767 /// </summary>
6868 /// <param name="inputText">The input text.</param>
6969 /// <returns>Tokens generated for the specified text input</returns>
/// <remarks>
/// Runs the tokenizer model synchronously and returns an already-completed task.
/// Because the method is not <c>async</c>, exceptions raised while running inference
/// propagate directly from the call instead of being captured in the returned task.
/// </remarks>
public Task<int[]> DecodeTextAsync(IModelOptions model, string inputText)
{
    // Nothing to tokenize: complete immediately with an empty token list.
    if (string.IsNullOrEmpty(inputText))
        return Task.FromResult(Array.Empty<int>());

    var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.Tokenizer);
    var outputNames = _onnxModelService.GetOutputNames(model, OnnxModelType.Tokenizer);

    // The tokenizer takes a single string element with shape [1].
    var inputTensor = new DenseTensor<string>(new string[] { inputText }, new int[] { 1 });
    using (var inputTensorValue = OrtValue.CreateFromStringTensor(inputTensor))
    {
        var outputs = new string[] { outputNames[0] };
        var inputs = new Dictionary<string, OrtValue> { { inputNames[0], inputTensorValue } };

        // Materialize the results exactly once so that every returned value can be
        // released below, not just the first one that is actually read.
        var results = _onnxModelService
            .RunInference(model, OnnxModelType.Tokenizer, inputs, outputs)
            .ToArray();
        try
        {
            // Only the first requested output carries the token ids (as 64-bit integers).
            var tokenData = results.First().GetTensorDataAsSpan<long>();

            // Copy out of the native buffer before it is released in the finally block.
            // Convert.ToInt32 throws OverflowException if an id does not fit in an int.
            var tokens = new int[tokenData.Length];
            for (var i = 0; i < tokenData.Length; i++)
                tokens[i] = Convert.ToInt32(tokenData[i]);

            return Task.FromResult(tokens);
        }
        finally
        {
            // Release every value produced by the run to avoid leaking native memory.
            foreach (var result in results)
                result.Dispose();
        }
    }
}
8790
@@ -95,14 +98,21 @@ public async Task<float[]> EncodeTokensAsync(IModelOptions model, int[] tokenize
9598 {
9699 // Create input tensor.
97100 var inputNames = _onnxModelService . GetInputNames ( model , OnnxModelType . TextEncoder ) ;
98- var inputTensor = TensorHelper . CreateTensor ( tokenizedInput , new [ ] { 1 , tokenizedInput . Length } ) ;
99- var inputParameters = CreateInputParameters ( NamedOnnxValue . CreateFromTensor ( inputNames [ 0 ] , inputTensor ) ) ;
101+ var outputNames = _onnxModelService . GetOutputNames ( model , OnnxModelType . TextEncoder ) ;
100102
101- // Run inference.
102- using ( var inferResult = await _onnxModelService . RunInferenceAsync ( model , OnnxModelType . TextEncoder , inputParameters ) )
103+ var inputDim = new [ ] { 1L , tokenizedInput . Length } ;
104+ var outputDim = new [ ] { 1L , tokenizedInput . Length , model . EmbeddingsLength } ;
105+ var outputBuffer = new float [ outputDim . GetBufferLength ( ) ] ;
106+ using ( var inputTensorValue = OrtValue . CreateTensorValueFromMemory ( tokenizedInput , inputDim ) )
107+ using ( var outputTensorValue = OrtValue . CreateTensorValueFromMemory ( outputBuffer , outputDim ) )
103108 {
104- var resultTensor = inferResult . FirstElementAs < DenseTensor < float > > ( ) ;
105- return resultTensor . ToArray ( ) ;
109+ var inputs = new Dictionary < string , OrtValue > { { inputNames [ 0 ] , inputTensorValue } } ;
110+ var outputs = new Dictionary < string , OrtValue > { { outputNames [ 0 ] , outputTensorValue } } ;
111+ var results = await _onnxModelService . RunInferenceAsync ( model , OnnxModelType . TextEncoder , inputs , outputs ) ;
112+ using ( var result = results . First ( ) )
113+ {
114+ return outputBuffer ;
115+ }
106116 }
107117 }
108118
0 commit comments