 using Microsoft.Extensions.Logging;
-using Microsoft.ML.OnnxRuntime;
 using Microsoft.ML.OnnxRuntime.Tensors;
+using OnnxStack.Core;
 using OnnxStack.Core.Config;
+using OnnxStack.Core.Model;
 using OnnxStack.Core.Services;
 using OnnxStack.StableDiffusion.Common;
 using OnnxStack.StableDiffusion.Config;
 using OnnxStack.StableDiffusion.Enums;
 using OnnxStack.StableDiffusion.Helpers;
 using SixLabors.ImageSharp;
+using SixLabors.ImageSharp.Processing;
 using System;
 using System.Collections.Generic;
+using System.Diagnostics;
 using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
 
 namespace OnnxStack.StableDiffusion.Diffusers.LatentConsistency
 {
@@ -40,7 +45,7 @@ public InpaintLegacyDiffuser(IOnnxModelService onnxModelService, IPromptService
         /// <param name="options">The options.</param>
         /// <param name="scheduler">The scheduler.</param>
         /// <returns></returns>
-        protected override IReadOnlyList<int> GetTimesteps(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler)
+        protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, IScheduler scheduler)
         {
             // Image2Image: narrow the timestep range by the Strength
             var inittimestep = Math.Min((int)(options.InferenceSteps * options.Strength), options.InferenceSteps);
@@ -50,31 +55,202 @@ protected override IReadOnlyList<int> GetTimesteps(PromptOptions prompt, Schedul
 
 
         /// <summary>
-        /// Prepares the latents for inference.
+        /// Runs the scheduler steps.
         /// </summary>
+        /// <param name="modelOptions">The model options.</param>
+        /// <param name="promptOptions">The prompt options.</param>
+        /// <param name="schedulerOptions">The scheduler options.</param>
+        /// <param name="promptEmbeddings">The prompt embeddings.</param>
+        /// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
+        /// <param name="progressCallback">The progress callback.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        /// <returns></returns>
+        protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, DenseTensor<float> promptEmbeddings, bool performGuidance, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
+        {
+            using (var scheduler = GetScheduler(schedulerOptions))
+            {
+                // Get timesteps
+                var timesteps = GetTimesteps(schedulerOptions, scheduler);
+
+                // Create latent sample
+                var latentsOriginal = await PrepareLatentsAsync(modelOptions, promptOptions, schedulerOptions, scheduler, timesteps);
+
+                // Create masks sample
+                var maskImage = PrepareMask(modelOptions, promptOptions, schedulerOptions);
+
+                // Generate some noise
+                var noise = scheduler.CreateRandomSample(latentsOriginal.Dimensions);
+
+                // Add noise to original latent
+                var latents = scheduler.AddNoise(latentsOriginal, noise, timesteps);
+
+                // Get Model metadata
+                var metadata = _onnxModelService.GetModelMetadata(modelOptions, OnnxModelType.Unet);
+
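+                // LCM conditions the UNet on the guidance scale through an embedding,
+                // rather than duplicating the batch for classifier-free guidance.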
+                // Get Guidance Scale Embedding
+                var guidanceEmbeddings = GetGuidanceScaleEmbedding(schedulerOptions.GuidanceScale);
+
+                // Denoised result
+                DenseTensor<float> denoised = null;
+
+                // Loop through the timesteps
+                var step = 0;
+                foreach (var timestep in timesteps)
+                {
+                    step++;
+                    var stepTime = Stopwatch.GetTimestamp();
+                    cancellationToken.ThrowIfCancellationRequested();
+
+                    // Create input tensor.
+                    var inputTensor = scheduler.ScaleInput(latents, timestep);
+                    var timestepTensor = CreateTimestepTensor(timestep);
+
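+                    // Output buffer for the UNet noise prediction, sized to the scaled latent dimensions.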
+                    var outputChannels = 1;
+                    var outputDimension = schedulerOptions.GetScaledDimension(outputChannels);
+                    using (var inferenceParameters = new OnnxInferenceParameters(metadata))
+                    {
+                        inferenceParameters.AddInputTensor(inputTensor);
+                        inferenceParameters.AddInputTensor(timestepTensor);
+                        inferenceParameters.AddInputTensor(promptEmbeddings);
+                        inferenceParameters.AddInputTensor(guidanceEmbeddings);
+                        inferenceParameters.AddOutputBuffer(outputDimension);
+
+                        var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inferenceParameters);
+                        using (var result = results.First())
+                        {
+                            var noisePred = result.ToDenseTensor();
+
+                            // Scheduler Step
+                            var schedulerResult = scheduler.Step(noisePred, timestep, latents);
+
+                            latents = schedulerResult.Result;
+                            denoised = schedulerResult.SampleData;
+
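+                            // Legacy inpaint: re-noise the original latents to the upcoming timestep and
+                            // blend them with the denoised latents via the mask, so the preserved region
+                            // keeps tracking the source image at every step.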
+                            // Add noise to original latent
+                            if (step < timesteps.Count - 1)
+                            {
+                                var noiseTimestep = timesteps[step + 1];
+                                var initLatentsProper = scheduler.AddNoise(latentsOriginal, noise, new[] { noiseTimestep });
+
+                                // Apply mask and combine
+                                latents = ApplyMaskedLatents(schedulerResult.Result, initLatentsProper, maskImage);
+                            }
+                        }
+                    }
+
+                    progressCallback?.Invoke(step, timesteps.Count);
+                    _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
+                }
+
+                // Decode Latents
+                return await DecodeLatentsAsync(modelOptions, promptOptions, schedulerOptions, denoised);
+            }
+        }
+
+
+        /// <summary>
+        /// Prepares the input latents for inference.
+        /// </summary>
+        /// <param name="model">The model.</param>
         /// <param name="prompt">The prompt.</param>
         /// <param name="options">The options.</param>
         /// <param name="scheduler">The scheduler.</param>
+        /// <param name="timesteps">The timesteps.</param>
         /// <returns></returns>
-        protected override DenseTensor<float> PrepareLatents(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
+        protected override async Task<DenseTensor<float>> PrepareLatentsAsync(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
         {
             // Image input: encode with the VAE, add noise, return as latent 0
             var imageTensor = prompt.InputImage.ToDenseTensor(new[] { 1, 3, options.Height, options.Width });
-            var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.VaeEncoder);
-            var inputParameters = CreateInputParameters(NamedOnnxValue.CreateFromTensor(inputNames[0], imageTensor));
-            using (var inferResult = _onnxModelService.RunInference(model, OnnxModelType.VaeEncoder, inputParameters))
+
+            //TODO: Model Config, Channels
+            var outputDimensions = options.GetScaledDimension();
+            var metadata = _onnxModelService.GetModelMetadata(model, OnnxModelType.VaeEncoder);
+            using (var inferenceParameters = new OnnxInferenceParameters(metadata))
             {
-                var sample = inferResult.FirstElementAs<DenseTensor<float>>();
-                var scaledSample = sample
-                    .Add(scheduler.CreateRandomSample(sample.Dimensions, options.InitialNoiseLevel))
-                    .MultiplyBy(model.ScaleFactor);
+                inferenceParameters.AddInputTensor(imageTensor);
+                inferenceParameters.AddOutputBuffer(outputDimensions);
+
+                var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeEncoder, inferenceParameters);
+                using (var result = results.First())
+                {
+                    var outputResult = result.ToDenseTensor();
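+                    // Add the configured initial noise, then scale by the model's VAE scale factor.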
+                    var scaledSample = outputResult
+                        .Add(scheduler.CreateRandomSample(outputDimensions, options.InitialNoiseLevel))
+                        .MultiplyBy(model.ScaleFactor);
+
+                    return scaledSample;
+                }
+            }
+        }
+
 
-                var noisySample = scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);
-                if (prompt.BatchCount > 1)
-                    return noisySample.Repeat(prompt.BatchCount);
+        /// <summary>
+        /// Prepares the mask.
+        /// </summary>
+        /// <param name="modelOptions">The model options.</param>
+        /// <param name="promptOptions">The prompt options.</param>
+        /// <param name="schedulerOptions">The scheduler options.</param>
+        /// <returns></returns>
+        private DenseTensor<float> PrepareMask(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions)
+        {
+            using (var mask = promptOptions.InputImageMask.ToImage())
+            {
+                // Prepare the mask
+                int width = schedulerOptions.GetScaledWidth();
+                int height = schedulerOptions.GetScaledHeight();
+                mask.Mutate(x => x.Grayscale());
+                mask.Mutate(x => x.Resize(new Size(width, height), KnownResamplers.NearestNeighbor, true));
+                var maskTensor = new DenseTensor<float>(new[] { 1, 4, height, width });
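+                // The alpha channel is normalised to [0,1] and inverted, so a value of 1 marks pixels
+                // that keep the original (re-noised) latents; only channel 0 carries data, the other
+                // channels exist to match the 4-channel latent shape.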
+                mask.ProcessPixelRows(img =>
+                {
+                    for (int x = 0; x < width; x++)
+                    {
+                        for (int y = 0; y < height; y++)
+                        {
+                            var pixelSpan = img.GetRowSpan(y);
+                            var value = pixelSpan[x].A / 255.0f;
+                            maskTensor[0, 0, y, x] = 1f - value;
+                            maskTensor[0, 1, y, x] = 0f; // Needed for shape only
+                            maskTensor[0, 2, y, x] = 0f; // Needed for shape only
+                            maskTensor[0, 3, y, x] = 0f; // Needed for shape only
+                        }
+                    }
+                });
+
+                return maskTensor;
+            }
+        }
+
+
+        /// <summary>
+        /// Applies the masked latents.
+        /// </summary>
+        /// <param name="latents">The latents.</param>
+        /// <param name="initLatentsProper">The re-noised original latents.</param>
+        /// <param name="mask">The mask.</param>
+        /// <returns></returns>
+        private DenseTensor<float> ApplyMaskedLatents(DenseTensor<float> latents, DenseTensor<float> initLatentsProper, DenseTensor<float> mask)
+        {
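+            // Element-wise blend: initLatentsProper * mask + latents * (1 - mask).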
+            var result = new DenseTensor<float>(latents.Dimensions);
+            for (int batch = 0; batch < latents.Dimensions[0]; batch++)
+            {
+                for (int channel = 0; channel < latents.Dimensions[1]; channel++)
+                {
+                    for (int height = 0; height < latents.Dimensions[2]; height++)
+                    {
+                        for (int width = 0; width < latents.Dimensions[3]; width++)
+                        {
+                            float maskValue = mask[batch, 0, height, width];
+                            float latentsValue = latents[batch, channel, height, width];
+                            float initLatentsProperValue = initLatentsProper[batch, channel, height, width];
 
-                return noisySample;
+                            // Keep the re-noised original latent where the mask is 1, the denoised latent where it is 0
+                            float newValue = initLatentsProperValue * maskValue + latentsValue * (1f - maskValue);
+                            result[batch, channel, height, width] = newValue;
+                        }
+                    }
+                }
             }
+            return result;
         }
     }
 }