@@ -105,7 +105,7 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(ModelOptions modelOpt
105105 // Create random seed if none was set
106106 schedulerOptions . Seed = schedulerOptions . Seed > 0 ? schedulerOptions . Seed : Random . Shared . Next ( ) ;
107107
108- var diffuseTime = _logger ? . LogBegin ( "Diffuse starting..." ) ;
108+ var diffuseTime = _logger ? . LogBegin ( "Diffuser starting..." ) ;
109109 _logger ? . Log ( $ "Model: { modelOptions . Name } , Pipeline: { modelOptions . PipelineType } , Diffuser: { promptOptions . DiffuserType } , Scheduler: { schedulerOptions . SchedulerType } ") ;
110110
111111 // Check guidance
@@ -114,36 +114,15 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(ModelOptions modelOpt
114114 // Process prompts
115115 var promptEmbeddings = await _promptService . CreatePromptAsync ( modelOptions . BaseModel , promptOptions , performGuidance ) ;
116116
117- // If video input, process frames
118- if ( promptOptions . HasInputVideo )
119- {
120- var frameIndex = 0 ;
121- DenseTensor < float > videoTensor = null ;
122- var videoFrames = promptOptions . InputVideo . VideoFrames . Frames ;
123- var schedulerFrameCallback = CreateBatchCallback ( progressCallback , videoFrames . Count , ( ) => frameIndex ) ;
124- foreach ( var videoFrame in videoFrames )
125- {
126- frameIndex ++ ;
127- promptOptions . InputImage = promptOptions . DiffuserType == DiffuserType . ControlNet ? default : new InputImage ( videoFrame ) ;
128- promptOptions . InputContolImage = promptOptions . DiffuserType == DiffuserType . ImageToImage ? default : new InputImage ( videoFrame ) ;
129- var frameResultTensor = await SchedulerStepAsync ( modelOptions , promptOptions , schedulerOptions , promptEmbeddings , performGuidance , schedulerFrameCallback , cancellationToken ) ;
130-
131- // Frame Progress
132- ReportBatchProgress ( progressCallback , frameIndex , videoFrames . Count , frameResultTensor ) ;
117+ var tensorResult = promptOptions . HasInputVideo
118+ ? await DiffuseVideoAsync ( modelOptions , promptOptions , schedulerOptions , promptEmbeddings , performGuidance , progressCallback , cancellationToken )
119+ : await DiffuseImageAsync ( modelOptions , promptOptions , schedulerOptions , promptEmbeddings , performGuidance , progressCallback , cancellationToken ) ;
133120
134- // Concatenate frame
135- videoTensor = videoTensor . Concatenate ( frameResultTensor ) ;
136- }
121+ _logger ? . LogEnd ( $ "Diffuser complete" , diffuseTime ) ;
122+ return tensorResult ;
123+ }
137124
138- _logger ? . LogEnd ( $ "Diffuse complete", diffuseTime ) ;
139- return videoTensor ;
140- }
141125
142- // Run Scheduler steps
143- var schedulerResult = await SchedulerStepAsync ( modelOptions , promptOptions , schedulerOptions , promptEmbeddings , performGuidance , progressCallback , cancellationToken ) ;
144- _logger ? . LogEnd ( $ "Diffuse complete", diffuseTime ) ;
145- return schedulerResult ;
146- }
147126
148127
149128
@@ -180,13 +159,73 @@ public virtual async IAsyncEnumerable<BatchResult> DiffuseBatchAsync(ModelOption
180159 var batchSchedulerCallback = CreateBatchCallback ( progressCallback , batchSchedulerOptions . Count , ( ) => batchIndex ) ;
181160 foreach ( var batchSchedulerOption in batchSchedulerOptions )
182161 {
183- var diffuseTime = _logger ? . LogBegin ( "Diffuse starting..." ) ;
184- yield return new BatchResult ( batchSchedulerOption , await SchedulerStepAsync ( modelOptions , promptOptions , batchSchedulerOption , promptEmbeddings , performGuidance , batchSchedulerCallback , cancellationToken ) ) ;
185- _logger ? . LogEnd ( $ "Diffuse complete", diffuseTime ) ;
162+ var tensorResult = promptOptions . HasInputVideo
163+ ? await DiffuseVideoAsync ( modelOptions , promptOptions , batchSchedulerOption , promptEmbeddings , performGuidance , progressCallback , cancellationToken )
164+ : await DiffuseImageAsync ( modelOptions , promptOptions , batchSchedulerOption , promptEmbeddings , performGuidance , batchSchedulerCallback , cancellationToken ) ;
165+
166+ yield return new BatchResult ( batchSchedulerOption , tensorResult ) ;
186167 batchIndex ++ ;
187168 }
188169
189- _logger ? . LogEnd ( $ "Diffuse batch complete", diffuseBatchTime ) ;
170+ _logger ? . LogEnd ( $ "Batch Diffuser complete", diffuseBatchTime ) ;
171+ }
172+
173+
174+ /// <summary>
175+ /// Diffuses the image.
176+ /// </summary>
177+ /// <param name="modelOptions">The model options.</param>
178+ /// <param name="promptOptions">The prompt options.</param>
179+ /// <param name="schedulerOptions">The scheduler options.</param>
180+ /// <param name="promptEmbeddings">The prompt embeddings.</param>
181+ /// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
182+ /// <param name="progressCallback">The progress callback.</param>
183+ /// <param name="cancellationToken">The cancellation token.</param>
184+ /// <returns></returns>
185+ protected virtual async Task < DenseTensor < float > > DiffuseImageAsync ( ModelOptions modelOptions , PromptOptions promptOptions , SchedulerOptions schedulerOptions , PromptEmbeddingsResult promptEmbeddings , bool performGuidance , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
186+ {
187+ var diffuseTime = _logger ? . LogBegin ( "Image Diffuser starting..." ) ;
188+ var schedulerResult = await SchedulerStepAsync ( modelOptions , promptOptions , schedulerOptions , promptEmbeddings , performGuidance , progressCallback , cancellationToken ) ;
189+ _logger ? . LogEnd ( $ "Image Diffuser complete", diffuseTime ) ;
190+ return schedulerResult ;
191+ }
192+
193+
194+ /// <summary>
195+ /// Diffuses the video.
196+ /// </summary>
197+ /// <param name="modelOptions">The model options.</param>
198+ /// <param name="promptOptions">The prompt options.</param>
199+ /// <param name="schedulerOptions">The scheduler options.</param>
200+ /// <param name="promptEmbeddings">The prompt embeddings.</param>
201+ /// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
202+ /// <param name="progressCallback">The progress callback.</param>
203+ /// <param name="cancellationToken">The cancellation token.</param>
204+ /// <returns></returns>
205+ protected virtual async Task < DenseTensor < float > > DiffuseVideoAsync ( ModelOptions modelOptions , PromptOptions promptOptions , SchedulerOptions schedulerOptions , PromptEmbeddingsResult promptEmbeddings , bool performGuidance , Action < DiffusionProgress > progressCallback = null , CancellationToken cancellationToken = default )
206+ {
207+ var diffuseTime = _logger ? . LogBegin ( "Video Diffuser starting..." ) ;
208+
209+ var frameIndex = 0 ;
210+ DenseTensor < float > videoTensor = null ;
211+ var videoFrames = promptOptions . InputVideo . VideoFrames . Frames ;
212+ var schedulerFrameCallback = CreateBatchCallback ( progressCallback , videoFrames . Count , ( ) => frameIndex ) ;
213+ foreach ( var videoFrame in videoFrames )
214+ {
215+ frameIndex ++ ;
216+ promptOptions . InputImage = promptOptions . DiffuserType == DiffuserType . ControlNet ? default : new InputImage ( videoFrame ) ;
217+ promptOptions . InputContolImage = promptOptions . DiffuserType == DiffuserType . ImageToImage ? default : new InputImage ( videoFrame ) ;
218+ var frameResultTensor = await SchedulerStepAsync ( modelOptions , promptOptions , schedulerOptions , promptEmbeddings , performGuidance , schedulerFrameCallback , cancellationToken ) ;
219+
220+ // Frame Progress
221+ ReportBatchProgress ( progressCallback , frameIndex , videoFrames . Count , frameResultTensor ) ;
222+
223+ // Concatenate frame
224+ videoTensor = videoTensor . Concatenate ( frameResultTensor ) ;
225+ }
226+
227+ _logger ? . LogEnd ( $ "Video Diffuser complete", diffuseTime ) ;
228+ return videoTensor ;
190229 }
191230
192231
0 commit comments