11using Microsoft . Extensions . Logging ;
22using Microsoft . ML . OnnxRuntime . Tensors ;
33using OnnxStack . Core ;
4- using OnnxStack . Core . Image ;
54using OnnxStack . Core . Model ;
65using OnnxStack . StableDiffusion . Common ;
76using OnnxStack . StableDiffusion . Config ;
@@ -53,40 +52,48 @@ public StableCascadeDiffuser(UNetConditionModel priorUnet, UNetConditionModel de
5352 /// <returns></returns>
5453 public override async Task < DenseTensor < float > > DiffuseAsync ( PromptOptions promptOptions , SchedulerOptions schedulerOptions , PromptEmbeddingsResult promptEmbeddings , bool performGuidance , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
5554 {
56- // Get Scheduler
57- using ( var schedulerPrior = GetScheduler ( schedulerOptions ) )
58- using ( var schedulerDecoder = GetScheduler ( schedulerOptions with { InferenceSteps = 10 , GuidanceScale = 0 } ) )
59- {
60- //----------------------------------------------------
61- // Prior Unet
62- //====================================================
55+ // Prior Unet
56+ var latentsPrior = await DiffusePriorAsync ( schedulerOptions , promptEmbeddings , performGuidance , progressCallback , cancellationToken ) ;
57+
58+ // Decoder Unet
59+ var schedulerOptionsDecoder = schedulerOptions with { InferenceSteps = 10 , GuidanceScale = 0 } ;
60+ var latents = await DiffuseDecodeAsync ( latentsPrior , schedulerOptionsDecoder , promptEmbeddings , performGuidance , progressCallback , cancellationToken ) ;
61+
62+ // Decode Latents
63+ return await DecodeLatentsAsync ( promptOptions , schedulerOptions , latents ) ;
64+ }
65+
6366
67+ protected async Task < DenseTensor < float > > DiffusePriorAsync ( SchedulerOptions schedulerOptions , PromptEmbeddingsResult promptEmbeddings , bool performGuidance , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
68+ {
69+ using ( var scheduler = GetScheduler ( schedulerOptions ) )
70+ {
6471 // Get timesteps
65- var timestepsPrior = GetTimesteps ( schedulerOptions , schedulerPrior ) ;
72+ var timesteps = GetTimesteps ( schedulerOptions , scheduler ) ;
6673
6774 // Create latent sample
68- var latentsPrior = schedulerPrior . CreateRandomSample ( new [ ] { 1 , 16 , ( int ) Math . Ceiling ( schedulerOptions . Height / 42.67f ) , ( int ) Math . Ceiling ( schedulerOptions . Width / 42.67f ) } , schedulerPrior . InitNoiseSigma ) ;
75+ var latents = scheduler . CreateRandomSample ( new [ ] { 1 , 16 , ( int ) Math . Ceiling ( schedulerOptions . Height / 42.67f ) , ( int ) Math . Ceiling ( schedulerOptions . Width / 42.67f ) } , scheduler . InitNoiseSigma ) ;
6976
7077 // Get Model metadata
71- var metadataPrior = await _unet . GetMetadataAsync ( ) ;
78+ var metadata = await _unet . GetMetadataAsync ( ) ;
7279
7380 // Loop though the timesteps
74- var stepPrior = 0 ;
75- foreach ( var timestep in timestepsPrior )
81+ var step = 0 ;
82+ foreach ( var timestep in timesteps )
7683 {
77- stepPrior ++ ;
84+ step ++ ;
7885 var stepTime = Stopwatch . GetTimestamp ( ) ;
7986 cancellationToken . ThrowIfCancellationRequested ( ) ;
8087
8188 // Create input tensor.
82- var inputLatent = performGuidance ? latentsPrior . Repeat ( 2 ) : latentsPrior ;
83- var inputTensor = schedulerPrior . ScaleInput ( inputLatent , timestep ) ;
84- var timestepTensor = CreateTimestepTensor ( inputLatent , timestep ) ;
85- var imageEmbeds = new DenseTensor < float > ( performGuidance ? new [ ] { 2 , 1 , 768 } : new [ ] { 1 , 1 , 768 } ) ;
89+ var inputLatent = performGuidance ? latents . Repeat ( 2 ) : latents ;
90+ var inputTensor = scheduler . ScaleInput ( inputLatent , timestep ) ;
91+ var timestepTensor = CreateTimestepTensor ( inputLatent , timestep ) ;
92+ var imageEmbeds = new DenseTensor < float > ( new [ ] { performGuidance ? 2 : 1 , 1 , 768 } ) ;
8693
8794 var outputChannels = performGuidance ? 2 : 1 ;
88- var outputDimension = inputTensor . Dimensions . ToArray ( ) ; //schedulerOptions.GetScaledDimension(outputChannels);
89- using ( var inferenceParameters = new OnnxInferenceParameters ( metadataPrior ) )
95+ var outputDimension = inputTensor . Dimensions . ToArray ( ) ;
96+ using ( var inferenceParameters = new OnnxInferenceParameters ( metadata ) )
9097 {
9198 inferenceParameters . AddInputTensor ( inputTensor ) ;
9299 inferenceParameters . AddInputTensor ( timestepTensor ) ;
@@ -105,58 +112,57 @@ public override async Task<DenseTensor<float>> DiffuseAsync(PromptOptions prompt
105112 noisePred = PerformGuidance ( noisePred , schedulerOptions . GuidanceScale ) ;
106113
107114 // Scheduler Step
108- latentsPrior = schedulerPrior . Step ( noisePred , timestep , latentsPrior ) . Result ;
115+ latents = scheduler . Step ( noisePred , timestep , latents ) . Result ;
109116 }
110117 }
111118
112- ReportProgress ( progressCallback , stepPrior , timestepsPrior . Count , latentsPrior ) ;
113- _logger ? . LogEnd ( LogLevel . Debug , $ "Step { stepPrior } /{ timestepsPrior . Count } ", stepTime ) ;
119+ ReportProgress ( progressCallback , step , timesteps . Count , latents ) ;
120+ _logger ? . LogEnd ( LogLevel . Debug , $ "Prior Step { step } /{ timesteps . Count } ", stepTime ) ;
114121 }
115122
116123 // Unload if required
117124 if ( _memoryMode == MemoryModeType . Minimum )
118125 await _unet . UnloadAsync ( ) ;
119126
127+ return latents ;
128+ }
129+ }
120130
121131
122-
123-
124- //----------------------------------------------------
125- // Decoder Unet
126- //====================================================
127-
132+ protected async Task < DenseTensor < float > > DiffuseDecodeAsync ( DenseTensor < float > latentsPrior , SchedulerOptions schedulerOptions , PromptEmbeddingsResult promptEmbeddings , bool performGuidance , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
133+ {
134+ using ( var scheduler = GetScheduler ( schedulerOptions ) )
135+ {
128136 // Get timesteps
129- var timestepsDecoder = GetTimesteps ( schedulerOptions , schedulerDecoder ) ;
137+ var timesteps = GetTimesteps ( schedulerOptions , scheduler ) ;
130138
131139 // Create latent sample
132-
133- var latentsDecoder = schedulerDecoder . CreateRandomSample ( new [ ] { 1 , 4 , ( int ) ( latentsPrior . Dimensions [ 2 ] * 10.67f ) , ( int ) ( latentsPrior . Dimensions [ 3 ] * 10.67f ) } , schedulerDecoder . InitNoiseSigma ) ;
140+ var latents = scheduler . CreateRandomSample ( new [ ] { 1 , 4 , ( int ) ( latentsPrior . Dimensions [ 2 ] * 10.67f ) , ( int ) ( latentsPrior . Dimensions [ 3 ] * 10.67f ) } , scheduler . InitNoiseSigma ) ;
134141
135142 // Get Model metadata
136- var metadataDecoder = await _decoderUnet . GetMetadataAsync ( ) ;
143+ var metadata = await _decoderUnet . GetMetadataAsync ( ) ;
137144
138145 var effnet = performGuidance
139146 ? latentsPrior
140147 : latentsPrior . Concatenate ( new DenseTensor < float > ( latentsPrior . Dimensions ) ) ;
141148
142149
143150 // Loop though the timesteps
144- var stepDecoder = 0 ;
145- foreach ( var timestep in timestepsDecoder )
151+ var step = 0 ;
152+ foreach ( var timestep in timesteps )
146153 {
147- stepDecoder ++ ;
154+ step ++ ;
148155 var stepTime = Stopwatch . GetTimestamp ( ) ;
149156 cancellationToken . ThrowIfCancellationRequested ( ) ;
150157
151158 // Create input tensor.
152- var inputLatent = performGuidance ? latentsDecoder . Repeat ( 2 ) : latentsDecoder ;
153- var inputTensor = schedulerDecoder . ScaleInput ( inputLatent , timestep ) ;
159+ var inputLatent = performGuidance ? latents . Repeat ( 2 ) : latents ;
160+ var inputTensor = scheduler . ScaleInput ( inputLatent , timestep ) ;
154161 var timestepTensor = CreateTimestepTensor ( inputLatent , timestep ) ;
155162
156-
157163 var outputChannels = performGuidance ? 2 : 1 ;
158164 var outputDimension = inputTensor . Dimensions . ToArray ( ) ; //schedulerOptions.GetScaledDimension(outputChannels);
159- using ( var inferenceParameters = new OnnxInferenceParameters ( metadataDecoder ) )
165+ using ( var inferenceParameters = new OnnxInferenceParameters ( metadata ) )
160166 {
161167 inferenceParameters . AddInputTensor ( inputTensor ) ;
162168 inferenceParameters . AddInputTensor ( timestepTensor ) ;
@@ -174,20 +180,19 @@ public override async Task<DenseTensor<float>> DiffuseAsync(PromptOptions prompt
174180 noisePred = PerformGuidance ( noisePred , schedulerOptions . GuidanceScale ) ;
175181
176182 // Scheduler Step
177- latentsDecoder = schedulerDecoder . Step ( noisePred , timestep , latentsDecoder ) . Result ;
183+ latents = scheduler . Step ( noisePred , timestep , latents ) . Result ;
178184 }
179185 }
180186
187+ ReportProgress ( progressCallback , step , timesteps . Count , latents ) ;
188+ _logger ? . LogEnd ( LogLevel . Debug , $ "Decoder Step { step } /{ timesteps . Count } ", stepTime ) ;
181189 }
182190
183- var testlatentsPrior = new OnnxImage ( latentsPrior ) ;
184- var testlatentsDecoder = new OnnxImage ( latentsDecoder ) ;
185- await testlatentsPrior . SaveAsync ( "D:\\ testlatentsPrior.png" ) ;
186- await testlatentsDecoder . SaveAsync ( "D:\\ latentsDecoder.png" ) ;
187-
191+ // Unload if required
192+ if ( _memoryMode == MemoryModeType . Minimum )
193+ await _unet . UnloadAsync ( ) ;
188194
189- // Decode Latents
190- return await DecodeLatentsAsync ( promptOptions , schedulerOptions , latentsDecoder ) ;
195+ return latents ;
191196 }
192197 }
193198
0 commit comments