99using OnnxStack . StableDiffusion . Enums ;
1010using OnnxStack . StableDiffusion . Helpers ;
1111using OnnxStack . StableDiffusion . Models ;
12+ using OnnxStack . StableDiffusion . Schedulers . StableDiffusion ;
1213using System ;
1314using System . Collections . Generic ;
1415using System . Diagnostics ;
1516using System . Linq ;
1617using System . Threading ;
1718using System . Threading . Tasks ;
1819
19- namespace OnnxStack . StableDiffusion . Diffusers . ControlNet
20+ namespace OnnxStack . StableDiffusion . Diffusers . StableDiffusion
2021{
21- public sealed class TextDiffuser : ControlNetDiffuser
22+ public class ControlNetDiffuser : DiffuserBase
2223 {
2324 /// <summary>
24- /// Initializes a new instance of the <see cref="TextDiffuser "/> class.
25+ /// Initializes a new instance of the <see cref="ControlNetDiffuser "/> class.
2526 /// </summary>
2627 /// <param name="configuration">The configuration.</param>
2728 /// <param name="onnxModelService">The onnx model service.</param>
28- public TextDiffuser ( IOnnxModelService onnxModelService , IPromptService promptService , ILogger < TextDiffuser > logger )
29- : base ( onnxModelService , promptService , logger )
30- {
31- }
29+ public ControlNetDiffuser ( IOnnxModelService onnxModelService , IPromptService promptService , ILogger < ControlNetDiffuser > logger )
30+ : base ( onnxModelService , promptService , logger ) { }
31+
32+
33+ /// <summary>
34+ /// Gets the type of the pipeline.
35+ /// </summary>
36+ public override DiffuserPipelineType PipelineType => DiffuserPipelineType . StableDiffusion ;
3237
3338
3439 /// <summary>
3540 /// Gets the type of the diffuser.
3641 /// </summary>
37- public override DiffuserType DiffuserType => DiffuserType . ImageToImage ;
42+ public override DiffuserType DiffuserType => DiffuserType . ControlNet ;
43+
3844
3945 /// <summary>
4046 /// Called on each Scheduler step.
@@ -47,7 +53,7 @@ public TextDiffuser(IOnnxModelService onnxModelService, IPromptService promptSer
4753 /// <param name="progressCallback">The progress callback.</param>
4854 /// <param name="cancellationToken">The cancellation token.</param>
4955 /// <returns></returns>
50- /// <exception cref="System. NotImplementedException"></exception>
56+ /// <exception cref="NotImplementedException"></exception>
5157 protected override async Task < DenseTensor < float > > SchedulerStepAsync ( StableDiffusionModelSet modelOptions , PromptOptions promptOptions , SchedulerOptions schedulerOptions , PromptEmbeddingsResult promptEmbeddings , bool performGuidance , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
5258 {
5359 // Get Scheduler
@@ -63,10 +69,10 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(StableDiffu
6369 var metadata = _onnxModelService . GetModelMetadata ( modelOptions , OnnxModelType . Unet ) ;
6470
6571 // Get Model metadata
66- var controlNetMetadata = _onnxModelService . GetModelMetadata ( modelOptions , OnnxModelType . Control ) ;
67-
72+ var controlNetMetadata = _onnxModelService . GetModelMetadata ( modelOptions , OnnxModelType . ControlNet ) ;
73+
6874 // Control Image
69- var controlImage = promptOptions . InputImage . ToDenseTensor ( new [ ] { 1 , 3 , schedulerOptions . Height , schedulerOptions . Width } , false ) ;
75+ var controlImage = PrepareControlImage ( promptOptions , schedulerOptions ) ;
7076
7177 // Loop though the timesteps
7278 var step = 0 ;
@@ -98,14 +104,15 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(StableDiffu
98104 controlNetParameters . AddInputTensor ( timestepTensor ) ;
99105 controlNetParameters . AddInputTensor ( promptEmbeddings . PromptEmbeds ) ;
100106 controlNetParameters . AddInputTensor ( controlImage ) ;
101- controlNetParameters . AddInputTensor ( conditioningScale ) ;
107+ if ( controlNetMetadata . Inputs . Count == 5 )
108+ controlNetParameters . AddInputTensor ( conditioningScale ) ;
102109
103110 // Optimization: Pre-allocate device buffers for inputs
104111 foreach ( var item in controlNetMetadata . Outputs )
105112 controlNetParameters . AddOutputBuffer ( ) ;
106113
107114 // ControlNet inference
108- var controlNetResults = _onnxModelService . RunInference ( modelOptions , OnnxModelType . Control , controlNetParameters ) ;
115+ var controlNetResults = _onnxModelService . RunInference ( modelOptions , OnnxModelType . ControlNet , controlNetParameters ) ;
109116
110117 // Add ControlNet outputs to Unet input
111118 foreach ( var item in controlNetResults )
@@ -139,14 +146,75 @@ protected override async Task<DenseTensor<float>> SchedulerStepAsync(StableDiffu
139146 }
140147 }
141148
149+
150+ /// <summary>
151+ /// Gets the timesteps.
152+ /// </summary>
153+ /// <param name="options">The options.</param>
154+ /// <param name="scheduler">The scheduler.</param>
155+ /// <returns></returns>
142156 protected override IReadOnlyList < int > GetTimesteps ( SchedulerOptions options , IScheduler scheduler )
143157 {
144158 return scheduler . Timesteps ;
145159 }
146160
161+
162+ /// <summary>
163+ /// Prepares the input latents.
164+ /// </summary>
165+ /// <param name="model">The model.</param>
166+ /// <param name="prompt">The prompt.</param>
167+ /// <param name="options">The options.</param>
168+ /// <param name="scheduler">The scheduler.</param>
169+ /// <param name="timesteps">The timesteps.</param>
170+ /// <returns></returns>
147171 protected override Task < DenseTensor < float > > PrepareLatentsAsync ( StableDiffusionModelSet model , PromptOptions prompt , SchedulerOptions options , IScheduler scheduler , IReadOnlyList < int > timesteps )
148172 {
149173 return Task . FromResult ( scheduler . CreateRandomSample ( options . GetScaledDimension ( ) , scheduler . InitNoiseSigma ) ) ;
150174 }
175+
176+
177+ /// <summary>
178+ /// Creates the Conditioning Scale tensor.
179+ /// </summary>
180+ /// <param name="conditioningScale">The conditioningScale.</param>
181+ /// <returns></returns>
182+ protected static DenseTensor < double > CreateConditioningScaleTensor ( float conditioningScale )
183+ {
184+ return TensorHelper . CreateTensor ( new double [ ] { conditioningScale } , new int [ ] { 1 } ) ;
185+ }
186+
187+
188+ /// <summary>
189+ /// Prepares the control image.
190+ /// </summary>
191+ /// <param name="promptOptions">The prompt options.</param>
192+ /// <param name="schedulerOptions">The scheduler options.</param>
193+ /// <returns></returns>
194+ protected DenseTensor < float > PrepareControlImage ( PromptOptions promptOptions , SchedulerOptions schedulerOptions )
195+ {
196+ return promptOptions . InputImage . ToDenseTensor ( new [ ] { 1 , 3 , schedulerOptions . Height , schedulerOptions . Width } , false ) ;
197+ }
198+
199+
200+ /// <summary>
201+ /// Gets the scheduler.
202+ /// </summary>
203+ /// <param name="options">The options.</param>
204+ /// <param name="schedulerConfig">The scheduler configuration.</param>
205+ /// <returns></returns>
206+ protected override IScheduler GetScheduler ( SchedulerOptions options )
207+ {
208+ return options . SchedulerType switch
209+ {
210+ SchedulerType . LMS => new LMSScheduler ( options ) ,
211+ SchedulerType . Euler => new EulerScheduler ( options ) ,
212+ SchedulerType . EulerAncestral => new EulerAncestralScheduler ( options ) ,
213+ SchedulerType . DDPM => new DDPMScheduler ( options ) ,
214+ SchedulerType . DDIM => new DDIMScheduler ( options ) ,
215+ SchedulerType . KDPM2 => new KDPM2Scheduler ( options ) ,
216+ _ => default
217+ } ;
218+ }
151219 }
152220}
0 commit comments