55using OnnxStack . StableDiffusion . Config ;
66using OnnxStack . StableDiffusion . Enums ;
77using OnnxStack . StableDiffusion . Helpers ;
8+ using OnnxStack . StableDiffusion . Models ;
89using SixLabors . ImageSharp ;
910using SixLabors . ImageSharp . PixelFormats ;
1011using System ;
@@ -150,7 +151,7 @@ public async Task<Stream> GenerateAsStreamAsync(IModelOptions model, PromptOptio
150151 /// <param name="progressCallback">The progress callback.</param>
151152 /// <param name="cancellationToken">The cancellation token.</param>
152153 /// <returns></returns>
153- public IAsyncEnumerable < DenseTensor < float > > GenerateBatchAsync ( IModelOptions modelOptions , PromptOptions promptOptions , SchedulerOptions schedulerOptions , BatchOptions batchOptions , Action < int , int , int , int > progressCallback = null , CancellationToken cancellationToken = default )
154+ public IAsyncEnumerable < BatchResult > GenerateBatchAsync ( IModelOptions modelOptions , PromptOptions promptOptions , SchedulerOptions schedulerOptions , BatchOptions batchOptions , Action < int , int , int , int > progressCallback = null , CancellationToken cancellationToken = default )
154155 {
155156 return DiffuseBatchAsync ( modelOptions , promptOptions , schedulerOptions , batchOptions , progressCallback , cancellationToken ) ;
156157 }
@@ -169,7 +170,7 @@ public IAsyncEnumerable<DenseTensor<float>> GenerateBatchAsync(IModelOptions mod
169170 public async IAsyncEnumerable < Image < Rgba32 > > GenerateBatchAsImageAsync ( IModelOptions modelOptions , PromptOptions promptOptions , SchedulerOptions schedulerOptions , BatchOptions batchOptions , Action < int , int , int , int > progressCallback = null , [ EnumeratorCancellation ] CancellationToken cancellationToken = default )
170171 {
171172 await foreach ( var result in GenerateBatchAsync ( modelOptions , promptOptions , schedulerOptions , batchOptions , progressCallback , cancellationToken ) )
172- yield return result . ToImage ( ) ;
173+ yield return result . ImageResult . ToImage ( ) ;
173174 }
174175
175176
@@ -186,7 +187,7 @@ public async IAsyncEnumerable<Image<Rgba32>> GenerateBatchAsImageAsync(IModelOpt
186187 public async IAsyncEnumerable < byte [ ] > GenerateBatchAsBytesAsync ( IModelOptions modelOptions , PromptOptions promptOptions , SchedulerOptions schedulerOptions , BatchOptions batchOptions , Action < int , int , int , int > progressCallback = null , [ EnumeratorCancellation ] CancellationToken cancellationToken = default )
187188 {
188189 await foreach ( var result in GenerateBatchAsync ( modelOptions , promptOptions , schedulerOptions , batchOptions , progressCallback , cancellationToken ) )
189- yield return result . ToImageBytes ( ) ;
190+ yield return result . ImageResult . ToImageBytes ( ) ;
190191 }
191192
192193
@@ -203,7 +204,7 @@ public async IAsyncEnumerable<byte[]> GenerateBatchAsBytesAsync(IModelOptions mo
203204 public async IAsyncEnumerable < Stream > GenerateBatchAsStreamAsync ( IModelOptions modelOptions , PromptOptions promptOptions , SchedulerOptions schedulerOptions , BatchOptions batchOptions , Action < int , int , int , int > progressCallback = null , [ EnumeratorCancellation ] CancellationToken cancellationToken = default )
204205 {
205206 await foreach ( var result in GenerateBatchAsync ( modelOptions , promptOptions , schedulerOptions , batchOptions , progressCallback , cancellationToken ) )
206- yield return result . ToImageStream ( ) ;
207+ yield return result . ImageResult . ToImageStream ( ) ;
207208 }
208209
209210
@@ -220,7 +221,7 @@ private async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOptions,
220221 }
221222
222223
223- private IAsyncEnumerable < DenseTensor < float > > DiffuseBatchAsync ( IModelOptions modelOptions , PromptOptions promptOptions , SchedulerOptions schedulerOptions , BatchOptions batchOptions , Action < int , int , int , int > progress = null , CancellationToken cancellationToken = default )
224+ private IAsyncEnumerable < BatchResult > DiffuseBatchAsync ( IModelOptions modelOptions , PromptOptions promptOptions , SchedulerOptions schedulerOptions , BatchOptions batchOptions , Action < int , int , int , int > progress = null , CancellationToken cancellationToken = default )
224225 {
225226 if ( ! _pipelines . TryGetValue ( modelOptions . PipelineType , out var pipeline ) )
226227 throw new Exception ( "Pipeline not found or is unsupported" ) ;
0 commit comments