11using Microsoft . ML . OnnxRuntime . Tensors ;
22using OnnxStack . Core ;
3- using OnnxStack . Core . Config ;
43using OnnxStack . Core . Services ;
54using OnnxStack . StableDiffusion . Common ;
65using OnnxStack . StableDiffusion . Config ;
@@ -26,6 +25,7 @@ namespace OnnxStack.StableDiffusion.Services
2625 /// <seealso cref="OnnxStack.StableDiffusion.Common.IStableDiffusionService" />
2726 public sealed class StableDiffusionService : IStableDiffusionService
2827 {
28+ private readonly IVideoService _videoService ;
2929 private readonly IOnnxModelService _modelService ;
3030 private readonly StableDiffusionConfig _configuration ;
3131 private readonly ConcurrentDictionary < DiffuserPipelineType , IPipeline > _pipelines ;
@@ -34,10 +34,11 @@ public sealed class StableDiffusionService : IStableDiffusionService
3434 /// Initializes a new instance of the <see cref="StableDiffusionService"/> class.
3535 /// </summary>
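/// <param name="configuration">The Stable Diffusion configuration.</param>
/// <param name="onnxModelService">The ONNX model service.</param>
/// <param name="videoService">The video service.</param>
/// <param name="pipelines">The available diffusion pipelines.</param>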
public StableDiffusionService(StableDiffusionConfig configuration, IOnnxModelService onnxModelService, IVideoService videoService, IEnumerable<IPipeline> pipelines)
{
    // Index the supplied pipelines by their pipeline type.
    _pipelines = pipelines.ToConcurrentDictionary(pipeline => pipeline.PipelineType, pipeline => pipeline);

    // Keep the injected collaborators for later calls.
    _videoService = videoService;
    _modelService = onnxModelService;
    _configuration = configuration;
}
4344
@@ -115,9 +116,11 @@ public async Task<Image<Rgba32>> GenerateAsImageAsync(StableDiffusionModelSet mo
115116 /// <returns>The diffusion result as <see cref="byte[]"/></returns>
public async Task<byte[]> GenerateAsBytesAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
{
    // Run the diffusion first; the resulting tensor is encoded below.
    var diffusionOutput = await GenerateAsync(model, prompt, options, progressCallback, cancellationToken).ConfigureAwait(false);

    if (prompt.HasInputVideo)
    {
        // Video input: hand the tensor to the video helper.
        return await GetVideoResultAsBytesAsync(options, diffusionOutput, cancellationToken).ConfigureAwait(false);
    }

    // No video input: encode the tensor as image bytes.
    return diffusionOutput.ToImageBytes();
}
122125
123126
@@ -131,9 +134,11 @@ public async Task<byte[]> GenerateAsBytesAsync(StableDiffusionModelSet model, Pr
131134 /// <returns>The diffusion result as <see cref="System.IO.Stream"/></returns>
public async Task<Stream> GenerateAsStreamAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
{
    // Run the diffusion first; the resulting tensor is wrapped in a stream below.
    var diffusionOutput = await GenerateAsync(model, prompt, options, progressCallback, cancellationToken).ConfigureAwait(false);

    if (prompt.HasInputVideo)
    {
        // Video input: hand the tensor to the video helper.
        return await GetVideoResultAsStreamAsync(options, diffusionOutput, cancellationToken).ConfigureAwait(false);
    }

    // No video input: expose the tensor as an image stream.
    return diffusionOutput.ToImageStream();
}
138143
139144
@@ -183,7 +188,12 @@ public async IAsyncEnumerable<Image<Rgba32>> GenerateBatchAsImageAsync(StableDif
public async IAsyncEnumerable<byte[]> GenerateBatchAsBytesAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    await foreach (var result in GenerateBatchAsync(modelOptions, promptOptions, schedulerOptions, batchOptions, progressCallback, cancellationToken))
    {
        // Yield exactly one item per batch result. `yield return` does not leave the
        // loop body, so without the else branch an image-only prompt would fall
        // through and also yield a video result for the same batch item.
        if (!promptOptions.HasInputVideo)
        {
            yield return result.ImageResult.ToImageBytes();
        }
        else
        {
            yield return await GetVideoResultAsBytesAsync(schedulerOptions, result.ImageResult, cancellationToken).ConfigureAwait(false);
        }
    }
}
188198
189199
@@ -200,7 +210,12 @@ public async IAsyncEnumerable<byte[]> GenerateBatchAsBytesAsync(StableDiffusionM
public async IAsyncEnumerable<Stream> GenerateBatchAsStreamAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<int, int, int, int> progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    await foreach (var result in GenerateBatchAsync(modelOptions, promptOptions, schedulerOptions, batchOptions, progressCallback, cancellationToken))
    {
        // Yield exactly one item per batch result. `yield return` does not leave the
        // loop body, so without the else branch an image-only prompt would fall
        // through and also yield a video stream for the same batch item.
        if (!promptOptions.HasInputVideo)
        {
            yield return result.ImageResult.ToImageStream();
        }
        else
        {
            yield return await GetVideoResultAsStreamAsync(schedulerOptions, result.ImageResult, cancellationToken).ConfigureAwait(false);
        }
    }
}
205220
206221
@@ -237,6 +252,21 @@ private IAsyncEnumerable<BatchResult> DiffuseBatchAsync(StableDiffusionModelSet
237252 return diffuser . DiffuseBatchAsync ( modelOptions , promptOptions , schedulerOptions , batchOptions , progress , cancellationToken ) ;
238253 }
239254
/// <summary>
/// Converts a diffusion result tensor into video bytes using the video service.
/// </summary>
/// <param name="options">The scheduler options; supplies the FPS passed to the video service.</param>
/// <param name="tensorResult">The diffusion result tensor.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>The data of the video created by the video service.</returns>
private async Task<byte[]> GetVideoResultAsBytesAsync(SchedulerOptions options, DenseTensor<float> tensorResult, CancellationToken cancellationToken = default)
{
    // Split by the size of the first dimension and encode each piece as image bytes
    // (presumably one piece per video frame -- confirm against Split).
    var frameImages = tensorResult
        .Split(tensorResult.Dimensions[0])
        .Select(frame => frame.ToImageBytes());

    // ConfigureAwait(false) matches the other awaits in this service.
    var videoResult = await _videoService.CreateVideoAsync(frameImages, options.VideoFPS, cancellationToken).ConfigureAwait(false);
    return videoResult.Data;
}
264+
/// <summary>
/// Converts a diffusion result tensor into video bytes and wraps them in a <see cref="MemoryStream"/>.
/// </summary>
/// <param name="options">The scheduler options forwarded to the byte conversion.</param>
/// <param name="tensorResult">The diffusion result tensor.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>A <see cref="MemoryStream"/> over the video bytes.</returns>
private async Task<MemoryStream> GetVideoResultAsStreamAsync(SchedulerOptions options, DenseTensor<float> tensorResult, CancellationToken cancellationToken = default)
{
    // ConfigureAwait(false) matches the other awaits in this service.
    var videoBytes = await GetVideoResultAsBytesAsync(options, tensorResult, cancellationToken).ConfigureAwait(false);
    return new MemoryStream(videoBytes);
}
269+
240270
241271 }
242272}
0 commit comments