1313using System . Collections . Generic ;
1414using System . Diagnostics ;
1515using System . Linq ;
16- using System . Runtime . CompilerServices ;
1716using System . Threading ;
1817using System . Threading . Tasks ;
1918
@@ -63,7 +62,7 @@ public override Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOptions
6362 /// <param name="progressCallback">The progress callback.</param>
6463 /// <param name="cancellationToken">The cancellation token.</param>
6564 /// <returns></returns>
66- public override IAsyncEnumerable < BatchResult > DiffuseBatchAsync ( IModelOptions modelOptions , PromptOptions promptOptions , SchedulerOptions schedulerOptions , BatchOptions batchOptions , Action < int , int , int , int > progressCallback = null , [ EnumeratorCancellation ] CancellationToken cancellationToken = default )
65+ public override IAsyncEnumerable < BatchResult > DiffuseBatchAsync ( IModelOptions modelOptions , PromptOptions promptOptions , SchedulerOptions schedulerOptions , BatchOptions batchOptions , Action < int , int , int , int > progressCallback = null , CancellationToken cancellationToken = default )
6766 {
6867 // LCM does not support negative prompting
6968 promptOptions . NegativePrompt = string . Empty ;
@@ -104,6 +103,11 @@ protected override async Task<DenseTensor<float>> SchedulerStep(IModelOptions mo
104103 // Denoised result
105104 DenseTensor < float > denoised = null ;
106105
106+ // Get Model metadata
107+ var inputNames = _onnxModelService . GetInputNames ( modelOptions , OnnxModelType . Unet ) ;
108+ var outputNames = _onnxModelService . GetOutputNames ( modelOptions , OnnxModelType . Unet ) ;
109+ var inputMetaData = _onnxModelService . GetInputMetadata ( modelOptions , OnnxModelType . Unet ) ;
110+
107111 // Loop though the timesteps
108112 var step = 0 ;
109113 foreach ( var timestep in timesteps )
@@ -115,19 +119,33 @@ protected override async Task<DenseTensor<float>> SchedulerStep(IModelOptions mo
115119 // Create input tensor.
116120 var inputTensor = scheduler . ScaleInput ( latents , timestep ) ;
117121
118- // Create Input Parameters
119- var inputParameters = CreateUnetInputParams ( modelOptions , inputTensor , promptEmbeddings , guidanceEmbeddings , timestep ) ;
120-
121- // Run Inference
122- using ( var inferResult = await _onnxModelService . RunInferenceAsync ( modelOptions , OnnxModelType . Unet , inputParameters ) )
122+ var outputBuffer = new DenseTensor < float > ( schedulerOptions . GetScaledDimension ( ) ) ;
123+ using ( var outputTensorValue = outputBuffer . ToOrtValue ( ) )
124+ using ( var inputTensorValue = inputTensor . ToOrtValue ( ) )
125+ using ( var timestepOrtValue = CreateTimestepNamedOrtValue ( inputMetaData , inputNames [ 1 ] , timestep ) )
126+ using ( var promptTensorValue = promptEmbeddings . ToOrtValue ( ) )
127+ using ( var guidanceTensorValue = guidanceEmbeddings . ToOrtValue ( ) )
123128 {
124- var noisePred = inferResult . FirstElementAs < DenseTensor < float > > ( ) ;
125-
126- // Scheduler Step
127- var schedulerResult = scheduler . Step ( noisePred , timestep , latents ) ;
128-
129- latents = schedulerResult . Result ;
130- denoised = schedulerResult . SampleData ;
129+ var inputs = new Dictionary < string , OrtValue >
130+ {
131+ { inputNames [ 0 ] , inputTensorValue } ,
132+ { inputNames [ 1 ] , timestepOrtValue } ,
133+ { inputNames [ 2 ] , promptTensorValue } ,
134+ { inputNames [ 3 ] , guidanceTensorValue }
135+ } ;
136+
137+ var outputs = new Dictionary < string , OrtValue > { { outputNames [ 0 ] , outputTensorValue } } ;
138+ var results = await _onnxModelService . RunInferenceAsync ( modelOptions , OnnxModelType . Unet , inputs , outputs ) ;
139+ using ( var result = results . First ( ) )
140+ {
141+ var noisePred = outputBuffer ;
142+
143+ // Scheduler Step
144+ var schedulerResult = scheduler . Step ( noisePred , timestep , latents ) ;
145+
146+ latents = schedulerResult . Result ;
147+ denoised = schedulerResult . SampleData ;
148+ }
131149 }
132150
133151 progressCallback ? . Invoke ( step , timesteps . Count ) ;
@@ -140,27 +158,6 @@ protected override async Task<DenseTensor<float>> SchedulerStep(IModelOptions mo
140158 }
141159
142160
143- /// <summary>
144- /// Creates the Unet input parameters.
145- /// </summary>
146- /// <param name="model">The model.</param>
147- /// <param name="inputTensor">The input tensor.</param>
148- /// <param name="promptEmbeddings">The prompt embeddings.</param>
149- /// <param name="timestep">The timestep.</param>
150- /// <returns></returns>
151- protected IReadOnlyList < NamedOnnxValue > CreateUnetInputParams ( IModelOptions model , DenseTensor < float > inputTensor , DenseTensor < float > promptEmbeddings , DenseTensor < float > guidanceEmbeddings , int timestep )
152- {
153- var inputNames = _onnxModelService . GetInputNames ( model , OnnxModelType . Unet ) ;
154- var inputMetaData = _onnxModelService . GetInputMetadata ( model , OnnxModelType . Unet ) ;
155- var timestepNamedOnnxValue = CreateTimestepNamedOnnxValue ( inputMetaData , inputNames [ 1 ] , timestep ) ;
156- return CreateInputParameters (
157- NamedOnnxValue . CreateFromTensor ( inputNames [ 0 ] , inputTensor ) ,
158- timestepNamedOnnxValue ,
159- NamedOnnxValue . CreateFromTensor ( inputNames [ 2 ] , promptEmbeddings ) ,
160- NamedOnnxValue . CreateFromTensor ( inputNames [ 3 ] , guidanceEmbeddings ) ) ;
161- }
162-
163-
164161 /// <summary>
165162 /// Gets the scheduler.
166163 /// </summary>
0 commit comments