@@ -34,6 +34,8 @@ public class StableDiffusionPipeline : PipelineBase
3434 protected IReadOnlyList < SchedulerType > _supportedSchedulers ;
3535 protected SchedulerOptions _defaultSchedulerOptions ;
3636
37+ protected sealed record BatchResultInternal ( SchedulerOptions SchedulerOptions , List < DenseTensor < float > > Result ) ;
38+
3739 /// <summary>
3840 /// Initializes a new instance of the <see cref="StableDiffusionPipeline"/> class.
3941 /// </summary>
@@ -165,35 +167,10 @@ public override void ValidateInputs(PromptOptions promptOptions, SchedulerOption
165167 /// <returns></returns>
166168 public override async Task < DenseTensor < float > > RunAsync ( PromptOptions promptOptions , SchedulerOptions schedulerOptions = default , ControlNetModel controlNet = default , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
167169 {
168- var diffuseTime = _logger ? . LogBegin ( "Diffuser starting..." ) ;
169- var options = GetSchedulerOptionsOrDefault ( schedulerOptions ) ;
170- _logger ? . Log ( $ "Model: { Name } , Pipeline: { PipelineType } , Diffuser: { promptOptions . DiffuserType } , Scheduler: { options . SchedulerType } ") ;
171-
172- // Check guidance
173- var performGuidance = ShouldPerformGuidance ( options ) ;
174-
175- // Process prompts
176- var promptEmbeddings = await CreatePromptEmbedsAsync ( promptOptions , performGuidance ) ;
177-
178- // Create Diffuser
179- var diffuser = CreateDiffuser ( promptOptions . DiffuserType , controlNet ) ;
180-
181- // Diffuse
182- var tensorResult = default ( DenseTensor < float > ) ;
183- if ( promptOptions . HasInputVideo )
184- {
185- await foreach ( var frameTensor in DiffuseVideoAsync ( diffuser , promptOptions , schedulerOptions , promptEmbeddings , performGuidance , progressCallback , cancellationToken ) )
186- {
187- tensorResult = tensorResult . Concatenate ( frameTensor ) ;
188- }
189- }
190- else
191- {
192- tensorResult = await DiffuseImageAsync ( diffuser , promptOptions , schedulerOptions , promptEmbeddings , performGuidance , progressCallback , cancellationToken ) ;
193- }
194-
195- _logger ? . LogEnd ( $ "Diffuser complete", diffuseTime ) ;
196- return tensorResult ;
170+ var tensors = await RunInternalAsync ( promptOptions , schedulerOptions , controlNet , progressCallback , cancellationToken ) ;
171+ return tensors . Count == 1
172+ ? tensors . First ( ) // ImageTensor
173+ : tensors . Join ( ) ; // VideoTensor
197174 }
198175
199176
@@ -209,45 +186,13 @@ public override async Task<DenseTensor<float>> RunAsync(PromptOptions promptOpti
209186 /// <returns></returns>
210187 public override async IAsyncEnumerable < BatchResult > RunBatchAsync ( BatchOptions batchOptions , PromptOptions promptOptions , SchedulerOptions schedulerOptions = default , ControlNetModel controlNet = default , Action < DiffusionProgress > progressCallback = null , [ EnumeratorCancellation ] CancellationToken cancellationToken = default )
211188 {
212- var diffuseBatchTime = _logger ? . LogBegin ( "Batch Diffuser starting..." ) ;
213- var options = GetSchedulerOptionsOrDefault ( schedulerOptions ) ;
214- _logger ? . Log ( $ "Model: { Name } , Pipeline: { PipelineType } , Diffuser: { promptOptions . DiffuserType } , Scheduler: { options . SchedulerType } ") ;
215- _logger ? . Log ( $ "BatchType: { batchOptions . BatchType } , ValueFrom: { batchOptions . ValueFrom } , ValueTo: { batchOptions . ValueTo } , Increment: { batchOptions . Increment } ") ;
216-
217- // Check guidance
218- var performGuidance = ShouldPerformGuidance ( options ) ;
219-
220- // Process prompts
221- var promptEmbeddings = await CreatePromptEmbedsAsync ( promptOptions , performGuidance ) ;
222-
223- // Generate batch options
224- var batchSchedulerOptions = BatchGenerator . GenerateBatch ( this , batchOptions , options ) ;
225-
226- // Create Diffuser
227- var diffuser = CreateDiffuser ( promptOptions . DiffuserType , controlNet ) ;
228-
229- // Diffuse
230- var batchIndex = 1 ;
231- var batchSchedulerCallback = CreateBatchCallback ( progressCallback , batchSchedulerOptions . Count , ( ) => batchIndex ) ;
232- foreach ( var batchSchedulerOption in batchSchedulerOptions )
189+ await foreach ( var batchResult in RunBatchInternalAsync ( batchOptions , promptOptions , schedulerOptions , controlNet , progressCallback , cancellationToken ) )
233190 {
234- var tensorResult = default ( DenseTensor < float > ) ;
235- if ( promptOptions . HasInputVideo )
236- {
237- await foreach ( var frameTensor in DiffuseVideoAsync ( diffuser , promptOptions , batchSchedulerOption , promptEmbeddings , performGuidance , progressCallback , cancellationToken ) )
238- {
239- tensorResult = tensorResult . Concatenate ( frameTensor ) ;
240- }
241- }
242- else
243- {
244- tensorResult = await DiffuseImageAsync ( diffuser , promptOptions , batchSchedulerOption , promptEmbeddings , performGuidance , progressCallback , cancellationToken ) ;
245- }
246- yield return new BatchResult ( batchSchedulerOption , tensorResult ) ;
247- batchIndex ++ ;
191+ var tensor = batchResult . Result . Count == 1
192+ ? batchResult . Result . First ( ) // ImageTensor
193+ : batchResult . Result . Join ( ) ; // VideoTensor
194+ yield return new BatchResult ( batchResult . SchedulerOptions , tensor ) ;
248195 }
249-
250- _logger ? . LogEnd ( $ "Batch Diffuser complete", diffuseBatchTime ) ;
251196 }
252197
253198
@@ -262,22 +207,8 @@ public override async IAsyncEnumerable<BatchResult> RunBatchAsync(BatchOptions b
262207 /// <returns></returns>
263208 public override async Task < OnnxImage > GenerateImageAsync ( PromptOptions promptOptions , SchedulerOptions schedulerOptions = default , ControlNetModel controlNet = default , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
264209 {
265- var diffuseTime = _logger ? . LogBegin ( "Diffuser starting..." ) ;
266- var options = GetSchedulerOptionsOrDefault ( schedulerOptions ) ;
267- _logger ? . Log ( $ "Model: { Name } , Pipeline: { PipelineType } , Diffuser: { promptOptions . DiffuserType } , Scheduler: { options . SchedulerType } ") ;
268-
269- // Check guidance
270- var performGuidance = ShouldPerformGuidance ( options ) ;
271-
272- // Process prompts
273- var promptEmbeddings = await CreatePromptEmbedsAsync ( promptOptions , performGuidance ) ;
274-
275- // Create Diffuser
276- var diffuser = CreateDiffuser ( promptOptions . DiffuserType , controlNet ) ;
277-
278- var imageResult = await DiffuseImageAsync ( diffuser , promptOptions , options , promptEmbeddings , performGuidance , progressCallback , cancellationToken ) ;
279-
280- return new OnnxImage ( imageResult ) ;
210+ var tensors = await RunInternalAsync ( promptOptions , schedulerOptions , controlNet , progressCallback , cancellationToken ) ;
211+ return new OnnxImage ( tensors . First ( ) ) ;
281212 }
282213
283214
@@ -293,47 +224,58 @@ public override async Task<OnnxImage> GenerateImageAsync(PromptOptions promptOpt
293224 /// <returns></returns>
294225 public override async IAsyncEnumerable < BatchImageResult > GenerateImageBatchAsync ( BatchOptions batchOptions , PromptOptions promptOptions , SchedulerOptions schedulerOptions = default , ControlNetModel controlNet = default , Action < DiffusionProgress > progressCallback = null , [ EnumeratorCancellation ] CancellationToken cancellationToken = default )
295226 {
296- var diffuseBatchTime = _logger ? . LogBegin ( "Batch Diffuser starting..." ) ;
297- var options = GetSchedulerOptionsOrDefault ( schedulerOptions ) ;
298- _logger ? . Log ( $ "Model: { Name } , Pipeline: { PipelineType } , Diffuser: { promptOptions . DiffuserType } , Scheduler: { options . SchedulerType } ") ;
299- _logger ? . Log ( $ "BatchType: { batchOptions . BatchType } , ValueFrom: { batchOptions . ValueFrom } , ValueTo: { batchOptions . ValueTo } , Increment: { batchOptions . Increment } ") ;
300-
301- // Check guidance
302- var performGuidance = ShouldPerformGuidance ( options ) ;
227+ await foreach ( var batchResult in RunBatchInternalAsync ( batchOptions , promptOptions , schedulerOptions , controlNet , progressCallback , cancellationToken ) )
228+ {
229+ yield return new BatchImageResult ( batchResult . SchedulerOptions , new OnnxImage ( batchResult . Result . First ( ) ) ) ;
230+ }
231+ }
303232
304- // Process prompts
305- var promptEmbeddings = await CreatePromptEmbedsAsync ( promptOptions , performGuidance ) ;
306233
307- // Generate batch options
308- var batchSchedulerOptions = BatchGenerator . GenerateBatch ( this , batchOptions , options ) ;
234+ /// <summary>
235+ /// Runs the pipeline returning the result as an OnnxVideo.
236+ /// </summary>
237+ /// <param name="promptOptions">The prompt options.</param>
238+ /// <param name="schedulerOptions">The scheduler options.</param>
239+ /// <param name="controlNet">The control net.</param>
240+ /// <param name="progressCallback">The progress callback.</param>
241+ /// <param name="cancellationToken">The cancellation token.</param>
242+ /// <returns></returns>
243+ public override async Task < OnnxVideo > GenerateVideoAsync ( PromptOptions promptOptions , SchedulerOptions schedulerOptions = default , ControlNetModel controlNet = default , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
244+ {
245+ var tensors = await RunInternalAsync ( promptOptions , schedulerOptions , controlNet , progressCallback , cancellationToken ) ;
246+ return new OnnxVideo ( promptOptions . InputVideo . Info , tensors ) ;
247+ }
309248
310- // Create Diffuser
311- var diffuser = CreateDiffuser ( promptOptions . DiffuserType , controlNet ) ;
312249
313- // Diffuse
314- var batchIndex = 1 ;
315- var batchSchedulerCallback = CreateBatchCallback ( progressCallback , batchSchedulerOptions . Count , ( ) => batchIndex ) ;
316- foreach ( var batchSchedulerOption in batchSchedulerOptions )
250+ /// <summary>
251+ /// Runs the batch pipeline returning the result as an OnnxVideo.
252+ /// </summary>
253+ /// <param name="batchOptions">The batch options.</param>
254+ /// <param name="promptOptions">The prompt options.</param>
255+ /// <param name="schedulerOptions">The scheduler options.</param>
256+ /// <param name="controlNet">The control net.</param>
257+ /// <param name="progressCallback">The progress callback.</param>
258+ /// <param name="cancellationToken">The cancellation token.</param>
259+ /// <returns></returns>
260+ public override async IAsyncEnumerable < BatchVideoResult > GenerateVideoBatchAsync ( BatchOptions batchOptions , PromptOptions promptOptions , SchedulerOptions schedulerOptions = default , ControlNetModel controlNet = default , Action < DiffusionProgress > progressCallback = null , [ EnumeratorCancellation ] CancellationToken cancellationToken = default )
261+ {
262+ await foreach ( var batchResult in RunBatchInternalAsync ( batchOptions , promptOptions , schedulerOptions , controlNet , progressCallback , cancellationToken ) )
317263 {
318- var tensorResult = await DiffuseImageAsync ( diffuser , promptOptions , batchSchedulerOption , promptEmbeddings , performGuidance , progressCallback , cancellationToken ) ;
319- yield return new BatchImageResult ( batchSchedulerOption , new OnnxImage ( tensorResult ) ) ;
320- batchIndex ++ ;
264+ yield return new BatchVideoResult ( batchResult . SchedulerOptions , new OnnxVideo ( promptOptions . InputVideo . Info , batchResult . Result ) ) ;
321265 }
322-
323- _logger ? . LogEnd ( $ "Batch Diffuser complete", diffuseBatchTime ) ;
324266 }
325267
326268
327269 /// <summary>
328- /// Runs the pipeline returning the result as an OnnxVideo.
270+ /// Runs the pipeline
329271 /// </summary>
330272 /// <param name="promptOptions">The prompt options.</param>
331273 /// <param name="schedulerOptions">The scheduler options.</param>
332274 /// <param name="controlNet">The control net.</param>
333275 /// <param name="progressCallback">The progress callback.</param>
334276 /// <param name="cancellationToken">The cancellation token.</param>
335277 /// <returns></returns>
336- public override async Task < OnnxVideo > GenerateVideoAsync ( PromptOptions promptOptions , SchedulerOptions schedulerOptions = default , ControlNetModel controlNet = default , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
278+ protected virtual async Task < List < DenseTensor < float > > > RunInternalAsync ( PromptOptions promptOptions , SchedulerOptions schedulerOptions = default , ControlNetModel controlNet = default , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
337279 {
338280 var diffuseTime = _logger ? . LogBegin ( "Diffuser starting..." ) ;
339281 var options = GetSchedulerOptionsOrDefault ( schedulerOptions ) ;
@@ -348,17 +290,30 @@ public override async Task<OnnxVideo> GenerateVideoAsync(PromptOptions promptOpt
348290 // Create Diffuser
349291 var diffuser = CreateDiffuser ( promptOptions . DiffuserType , controlNet ) ;
350292
351- var frames = new List < OnnxImage > ( ) ;
352- await foreach ( var frameTensor in DiffuseVideoAsync ( diffuser , promptOptions , options , promptEmbeddings , performGuidance , progressCallback , cancellationToken ) )
293+ // Diffuse
294+ var tensorResult = new List < DenseTensor < float > > ( ) ;
295+ if ( promptOptions . HasInputVideo )
353296 {
354- frames . Add ( new OnnxImage ( frameTensor ) ) ;
297+ var frameIndex = 1 ;
298+ var frameSchedulerCallback = CreateBatchCallback ( progressCallback , promptOptions . InputVideo . Frames . Count , ( ) => frameIndex ) ;
299+ await foreach ( var frameTensor in DiffuseVideoAsync ( diffuser , promptOptions , options , promptEmbeddings , performGuidance , frameSchedulerCallback , cancellationToken ) )
300+ {
301+ frameIndex ++ ;
302+ tensorResult . Add ( frameTensor ) ;
303+ }
355304 }
356- return new OnnxVideo ( promptOptions . InputVideo . Info , frames ) ;
305+ else
306+ {
307+ tensorResult . Add ( await DiffuseImageAsync ( diffuser , promptOptions , options , promptEmbeddings , performGuidance , progressCallback , cancellationToken ) ) ;
308+ }
309+
310+ _logger ? . LogEnd ( $ "Diffuser complete", diffuseTime ) ;
311+ return tensorResult ;
357312 }
358313
359314
360315 /// <summary>
361- /// Runs the batch pipeline returning the result as an OnnxVideo .
316+ /// Runs the pipeline batch .
362317 /// </summary>
363318 /// <param name="batchOptions">The batch options.</param>
364319 /// <param name="promptOptions">The prompt options.</param>
@@ -367,7 +322,7 @@ public override async Task<OnnxVideo> GenerateVideoAsync(PromptOptions promptOpt
367322 /// <param name="progressCallback">The progress callback.</param>
368323 /// <param name="cancellationToken">The cancellation token.</param>
369324 /// <returns></returns>
370- public override async IAsyncEnumerable < BatchVideoResult > GenerateVideoBatchAsync ( BatchOptions batchOptions , PromptOptions promptOptions , SchedulerOptions schedulerOptions = default , ControlNetModel controlNet = default , Action < DiffusionProgress > progressCallback = null , [ EnumeratorCancellation ] CancellationToken cancellationToken = default )
325+ protected virtual async IAsyncEnumerable < BatchResultInternal > RunBatchInternalAsync ( BatchOptions batchOptions , PromptOptions promptOptions , SchedulerOptions schedulerOptions = default , ControlNetModel controlNet = default , Action < DiffusionProgress > progressCallback = null , [ EnumeratorCancellation ] CancellationToken cancellationToken = default )
371326 {
372327 var diffuseBatchTime = _logger ? . LogBegin ( "Batch Diffuser starting..." ) ;
373328 var options = GetSchedulerOptionsOrDefault ( schedulerOptions ) ;
@@ -387,19 +342,26 @@ public override async IAsyncEnumerable<BatchVideoResult> GenerateVideoBatchAsync
387342 var diffuser = CreateDiffuser ( promptOptions . DiffuserType , controlNet ) ;
388343
389344 // Diffuse
390- var batchIndex = 1 ;
345+ var batchIndex = 1 ; // TODO: Video batch callback shoud be (BatchIndex + FrameIndex), not (BatchIndex + StepIndex)
391346 var batchSchedulerCallback = CreateBatchCallback ( progressCallback , batchSchedulerOptions . Count , ( ) => batchIndex ) ;
392347 foreach ( var batchSchedulerOption in batchSchedulerOptions )
393348 {
394- var frames = new List < OnnxImage > ( ) ;
395- await foreach ( var frameTensor in DiffuseVideoAsync ( diffuser , promptOptions , batchSchedulerOption , promptEmbeddings , performGuidance , progressCallback , cancellationToken ) )
349+ var tensorResult = new List < DenseTensor < float > > ( ) ;
350+ if ( promptOptions . HasInputVideo )
396351 {
397- frames . Add ( new OnnxImage ( frameTensor ) ) ;
352+ await foreach ( var frameTensor in DiffuseVideoAsync ( diffuser , promptOptions , batchSchedulerOption , promptEmbeddings , performGuidance , batchSchedulerCallback , cancellationToken ) )
353+ {
354+ tensorResult . Add ( frameTensor ) ;
355+ }
356+ }
357+ else
358+ {
359+ tensorResult . Add ( await DiffuseImageAsync ( diffuser , promptOptions , batchSchedulerOption , promptEmbeddings , performGuidance , batchSchedulerCallback , cancellationToken ) ) ;
398360 }
399- yield return new BatchVideoResult ( batchSchedulerOption , new OnnxVideo ( promptOptions . InputVideo . Info , frames ) ) ;
361+
400362 batchIndex ++ ;
363+ yield return new BatchResultInternal ( batchSchedulerOption , tensorResult ) ;
401364 }
402-
403365 _logger ? . LogEnd ( $ "Batch Diffuser complete", diffuseBatchTime ) ;
404366 }
405367
@@ -623,5 +585,4 @@ public static StableDiffusionPipeline CreatePipeline(string modelFolder, ModelTy
623585 return CreatePipeline ( ModelFactory . CreateModelSet ( modelFolder , DiffuserPipelineType . StableDiffusion , modelType , deviceId , executionProvider , memoryMode ) , logger ) ;
624586 }
625587 }
626-
627588}
0 commit comments