88using OnnxStack . StableDiffusion . Models ;
99using OnnxStack . StableDiffusion . Schedulers . StableDiffusion ;
1010using System ;
11+ using System . Collections . Generic ;
1112using System . Diagnostics ;
1213using System . Linq ;
1314using System . Threading ;
@@ -17,6 +18,9 @@ namespace OnnxStack.StableDiffusion.Diffusers.StableCascade
1718{
1819 public abstract class StableCascadeDiffuser : DiffuserBase
1920 {
21+ private readonly float _latentDimScale ;
22+ private readonly float _resolutionMultiple ;
23+ private readonly int _clipImageChannels ;
2024 private readonly UNetConditionModel _decoderUnet ;
2125
2226 /// <summary>
@@ -32,6 +36,9 @@ public StableCascadeDiffuser(UNetConditionModel priorUnet, UNetConditionModel de
3236 : base ( priorUnet , decoderVqgan , imageEncoder , memoryMode , logger )
3337 {
3438 _decoderUnet = decoderUnet ;
39+ _latentDimScale = 10.67f ;
40+ _resolutionMultiple = 42.67f ;
41+ _clipImageChannels = 768 ;
3542 }
3643
3744 /// <summary>
@@ -40,6 +47,32 @@ public StableCascadeDiffuser(UNetConditionModel priorUnet, UNetConditionModel de
4047 public override DiffuserPipelineType PipelineType => DiffuserPipelineType . StableCascade ;
4148
4249
50+ /// <summary>
51+ /// Multiplier to determine the VQ latent space size from the image embeddings. If the image embeddings are
52+ /// height=24 and width = 24, the VQ latent shape needs to be height=int (24*10.67)=256 and
53+ /// width = int(24 * 10.67) = 256 in order to match the training conditions.
54+ /// </summary>
55+ protected float LatentDimScale => _latentDimScale ;
56+
57+
58+ /// <summary>
59+ /// Default resolution for multiple images generated
60+ /// </summary>
61+ protected float ResolutionMultiple => _resolutionMultiple ;
62+
63+
64+ /// <summary>
65+ /// Prepares the decoder latents.
66+ /// </summary>
67+ /// <param name="prompt">The prompt.</param>
68+ /// <param name="options">The options.</param>
69+ /// <param name="scheduler">The scheduler.</param>
70+ /// <param name="timesteps">The timesteps.</param>
71+ /// <param name="priorLatents">The prior latents.</param>
72+ /// <returns></returns>
73+ protected abstract Task < DenseTensor < float > > PrepareDecoderLatentsAsync ( PromptOptions prompt , SchedulerOptions options , IScheduler scheduler , IReadOnlyList < int > timesteps , DenseTensor < float > priorLatents ) ;
74+
75+
4376 /// <summary>
4477 /// Runs the scheduler steps.
4578 /// </summary>
@@ -52,27 +85,55 @@ public StableCascadeDiffuser(UNetConditionModel priorUnet, UNetConditionModel de
5285 /// <returns></returns>
5386 public override async Task < DenseTensor < float > > DiffuseAsync ( PromptOptions promptOptions , SchedulerOptions schedulerOptions , PromptEmbeddingsResult promptEmbeddings , bool performGuidance , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
5487 {
88+ var decodeSchedulerOptions = schedulerOptions with
89+ {
90+ InferenceSteps = schedulerOptions . InferenceSteps2 ,
91+ GuidanceScale = schedulerOptions . GuidanceScale2
92+ } ;
93+
94+ var priorPromptEmbeddings = promptEmbeddings ;
95+ var decoderPromptEmbeddings = promptEmbeddings ;
96+ var priorPerformGuidance = schedulerOptions . GuidanceScale > 0 ;
97+ var decoderPerformGuidance = decodeSchedulerOptions . GuidanceScale > 0 ;
98+ if ( performGuidance )
99+ {
100+ if ( ! priorPerformGuidance )
101+ priorPromptEmbeddings = SplitPromptEmbeddings ( promptEmbeddings ) ;
102+ if ( ! decoderPerformGuidance )
103+ decoderPromptEmbeddings = SplitPromptEmbeddings ( promptEmbeddings ) ;
104+ }
105+
55106 // Prior Unet
56- var latentsPrior = await DiffusePriorAsync ( schedulerOptions , promptEmbeddings , performGuidance , progressCallback , cancellationToken ) ;
107+ var priorLatents = await DiffusePriorAsync ( promptOptions , schedulerOptions , priorPromptEmbeddings , priorPerformGuidance , progressCallback , cancellationToken ) ;
57108
58109 // Decoder Unet
59- var schedulerOptionsDecoder = schedulerOptions with { InferenceSteps = 10 , GuidanceScale = 0 } ;
60- var latents = await DiffuseDecodeAsync ( latentsPrior , schedulerOptionsDecoder , promptEmbeddings , performGuidance , progressCallback , cancellationToken ) ;
110+ var decoderLatents = await DiffuseDecodeAsync ( promptOptions , priorLatents , decodeSchedulerOptions , decoderPromptEmbeddings , decoderPerformGuidance , progressCallback , cancellationToken ) ;
61111
62112 // Decode Latents
63- return await DecodeLatentsAsync ( promptOptions , schedulerOptions , latents ) ;
113+ return await DecodeLatentsAsync ( promptOptions , schedulerOptions , decoderLatents ) ;
64114 }
65115
66116
67- protected async Task < DenseTensor < float > > DiffusePriorAsync ( SchedulerOptions schedulerOptions , PromptEmbeddingsResult promptEmbeddings , bool performGuidance , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
117+
118+ /// <summary>
119+ /// Run the Prior UNET diffusion
120+ /// </summary>
121+ /// <param name="prompt">The prompt.</param>
122+ /// <param name="schedulerOptions">The scheduler options.</param>
123+ /// <param name="promptEmbeddings">The prompt embeddings.</param>
124+ /// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
125+ /// <param name="progressCallback">The progress callback.</param>
126+ /// <param name="cancellationToken">The cancellation token.</param>
127+ /// <returns></returns>
128+ protected async Task < DenseTensor < float > > DiffusePriorAsync ( PromptOptions prompt , SchedulerOptions schedulerOptions , PromptEmbeddingsResult promptEmbeddings , bool performGuidance , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
68129 {
69130 using ( var scheduler = GetScheduler ( schedulerOptions ) )
70131 {
71132 // Get timesteps
72133 var timesteps = GetTimesteps ( schedulerOptions , scheduler ) ;
73134
74135 // Create latent sample
75- var latents = scheduler . CreateRandomSample ( new [ ] { 1 , 16 , ( int ) Math . Ceiling ( schedulerOptions . Height / 42.67f ) , ( int ) Math . Ceiling ( schedulerOptions . Width / 42.67f ) } , scheduler . InitNoiseSigma ) ;
136+ var latents = await PrepareLatentsAsync ( prompt , schedulerOptions , scheduler , timesteps ) ;
76137
77138 // Get Model metadata
78139 var metadata = await _unet . GetMetadataAsync ( ) ;
@@ -89,18 +150,15 @@ protected async Task<DenseTensor<float>> DiffusePriorAsync(SchedulerOptions sche
89150 var inputLatent = performGuidance ? latents . Repeat ( 2 ) : latents ;
90151 var inputTensor = scheduler . ScaleInput ( inputLatent , timestep ) ;
91152 var timestepTensor = CreateTimestepTensor ( inputLatent , timestep ) ;
92- var imageEmbeds = new DenseTensor < float > ( new [ ] { performGuidance ? 2 : 1 , 1 , 768 } ) ;
93-
94- var outputChannels = performGuidance ? 2 : 1 ;
95- var outputDimension = inputTensor . Dimensions . ToArray ( ) ;
153+ var imageEmbeds = new DenseTensor < float > ( new [ ] { performGuidance ? 2 : 1 , 1 , _clipImageChannels } ) ;
96154 using ( var inferenceParameters = new OnnxInferenceParameters ( metadata ) )
97155 {
98156 inferenceParameters . AddInputTensor ( inputTensor ) ;
99157 inferenceParameters . AddInputTensor ( timestepTensor ) ;
100158 inferenceParameters . AddInputTensor ( promptEmbeddings . PooledPromptEmbeds ) ;
101159 inferenceParameters . AddInputTensor ( promptEmbeddings . PromptEmbeds ) ;
102160 inferenceParameters . AddInputTensor ( imageEmbeds ) ;
103- inferenceParameters . AddOutputBuffer ( outputDimension ) ;
161+ inferenceParameters . AddOutputBuffer ( inputTensor . Dimensions ) ;
104162
105163 var results = await _unet . RunInferenceAsync ( inferenceParameters ) ;
106164 using ( var result = results . First ( ) )
@@ -129,23 +187,33 @@ protected async Task<DenseTensor<float>> DiffusePriorAsync(SchedulerOptions sche
129187 }
130188
131189
132- protected async Task < DenseTensor < float > > DiffuseDecodeAsync ( DenseTensor < float > latentsPrior , SchedulerOptions schedulerOptions , PromptEmbeddingsResult promptEmbeddings , bool performGuidance , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
190+ /// <summary>
191+ /// Run the Decoder UNET diffusion
192+ /// </summary>
193+ /// <param name="prompt">The prompt.</param>
194+ /// <param name="priorLatents">The prior latents.</param>
195+ /// <param name="schedulerOptions">The scheduler options.</param>
196+ /// <param name="promptEmbeddings">The prompt embeddings.</param>
197+ /// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
198+ /// <param name="progressCallback">The progress callback.</param>
199+ /// <param name="cancellationToken">The cancellation token.</param>
200+ /// <returns></returns>
201+ protected async Task < DenseTensor < float > > DiffuseDecodeAsync ( PromptOptions prompt , DenseTensor < float > priorLatents , SchedulerOptions schedulerOptions , PromptEmbeddingsResult promptEmbeddings , bool performGuidance , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
133202 {
134203 using ( var scheduler = GetScheduler ( schedulerOptions ) )
135204 {
136205 // Get timesteps
137206 var timesteps = GetTimesteps ( schedulerOptions , scheduler ) ;
138207
139208 // Create latent sample
140- var latents = scheduler . CreateRandomSample ( new [ ] { 1 , 4 , ( int ) ( latentsPrior . Dimensions [ 2 ] * 10.67f ) , ( int ) ( latentsPrior . Dimensions [ 3 ] * 10.67f ) } , scheduler . InitNoiseSigma ) ;
209+ var latents = await PrepareDecoderLatentsAsync ( prompt , schedulerOptions , scheduler , timesteps , priorLatents ) ;
141210
142211 // Get Model metadata
143212 var metadata = await _decoderUnet . GetMetadataAsync ( ) ;
144213
145- var effnet = performGuidance
146- ? latentsPrior
147- : latentsPrior . Concatenate ( new DenseTensor < float > ( latentsPrior . Dimensions ) ) ;
148-
214+ var effnet = ! performGuidance
215+ ? priorLatents
216+ : priorLatents . Repeat ( 2 ) ;
149217
150218 // Loop though the timesteps
151219 var step = 0 ;
@@ -159,18 +227,15 @@ protected async Task<DenseTensor<float>> DiffuseDecodeAsync(DenseTensor<float> l
159227 var inputLatent = performGuidance ? latents . Repeat ( 2 ) : latents ;
160228 var inputTensor = scheduler . ScaleInput ( inputLatent , timestep ) ;
161229 var timestepTensor = CreateTimestepTensor ( inputLatent , timestep ) ;
162-
163- var outputChannels = performGuidance ? 2 : 1 ;
164- var outputDimension = inputTensor . Dimensions . ToArray ( ) ; //schedulerOptions.GetScaledDimension(outputChannels);
165230 using ( var inferenceParameters = new OnnxInferenceParameters ( metadata ) )
166231 {
167232 inferenceParameters . AddInputTensor ( inputTensor ) ;
168233 inferenceParameters . AddInputTensor ( timestepTensor ) ;
169234 inferenceParameters . AddInputTensor ( promptEmbeddings . PooledPromptEmbeds ) ;
170235 inferenceParameters . AddInputTensor ( effnet ) ;
171- inferenceParameters . AddOutputBuffer ( ) ;
236+ inferenceParameters . AddOutputBuffer ( inputTensor . Dimensions ) ;
172237
173- var results = _decoderUnet . RunInference ( inferenceParameters ) ;
238+ var results = await _decoderUnet . RunInferenceAsync ( inferenceParameters ) ;
174239 using ( var result = results . First ( ) )
175240 {
176241 var noisePred = result . ToDenseTensor ( ) ;
@@ -197,6 +262,13 @@ protected async Task<DenseTensor<float>> DiffuseDecodeAsync(DenseTensor<float> l
197262 }
198263
199264
265+ /// <summary>
266+ /// Decodes the latents.
267+ /// </summary>
268+ /// <param name="prompt">The prompt.</param>
269+ /// <param name="options">The options.</param>
270+ /// <param name="latents">The latents.</param>
271+ /// <returns></returns>
200272 protected override async Task < DenseTensor < float > > DecodeLatentsAsync ( PromptOptions prompt , SchedulerOptions options , DenseTensor < float > latents )
201273 {
202274 latents = latents . MultiplyBy ( _vaeDecoder . ScaleFactor ) ;
@@ -239,6 +311,19 @@ private DenseTensor<float> CreateTimestepTensor(DenseTensor<float> latents, int
239311 }
240312
241313
314+ /// <summary>
315+ /// Splits the prompt embeddings, Removes unconditional embeddings
316+ /// </summary>
317+ /// <param name="promptEmbeddings">The prompt embeddings.</param>
318+ /// <returns></returns>
319+ private PromptEmbeddingsResult SplitPromptEmbeddings ( PromptEmbeddingsResult promptEmbeddings )
320+ {
321+ return promptEmbeddings . PooledPromptEmbeds is null
322+ ? new PromptEmbeddingsResult ( promptEmbeddings . PromptEmbeds . SplitBatch ( ) . Last ( ) )
323+ : new PromptEmbeddingsResult ( promptEmbeddings . PromptEmbeds . SplitBatch ( ) . Last ( ) , promptEmbeddings . PooledPromptEmbeds . SplitBatch ( ) . Last ( ) ) ;
324+ }
325+
326+
242327 /// <summary>
243328 /// Gets the scheduler.
244329 /// </summary>
0 commit comments