Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit d8124a3

Browse files
committed
Add ProgressCallback and Cancellation support
1 parent 0c12c6e commit d8124a3

File tree

4 files changed

+31
-6
lines changed

4 files changed

+31
-6
lines changed

OnnxStack.StableDiffusion/Common/ISchedulerService.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
using Microsoft.ML.OnnxRuntime.Tensors;
22
using OnnxStack.StableDiffusion.Config;
3+
using System;
4+
using System.Threading;
35
using System.Threading.Tasks;
46

57
namespace OnnxStack.StableDiffusion.Common
@@ -13,6 +15,6 @@ public interface ISchedulerService
1315
/// <param name="prompt">The prompt.</param>
1416
/// <param name="options">The options.</param>
1517
/// <returns></returns>
16-
Task<DenseTensor<float>> RunAsync(PromptOptions prompt, SchedulerOptions options);
18+
Task<DenseTensor<float>> RunAsync(PromptOptions prompt, SchedulerOptions options, Action<int, int> progress = null, CancellationToken cancellationToken = default);
1719
}
1820
}

OnnxStack.StableDiffusion/Common/IStableDiffusionService.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using OnnxStack.StableDiffusion.Config;
22
using OnnxStack.StableDiffusion.Results;
3+
using System;
34
using System.Collections.Generic;
5+
using System.Threading;
46
using System.Threading.Tasks;
57

68
namespace OnnxStack.StableDiffusion.Common
@@ -9,8 +11,10 @@ public interface IStableDiffusionService
911
{
1012
Task<ImageResult> TextToImage(PromptOptions prompt);
1113
Task<ImageResult> TextToImage(PromptOptions prompt, SchedulerOptions options);
14+
Task<ImageResult> TextToImage(PromptOptions prompt, SchedulerOptions options, Action<int, int> progress = null, CancellationToken cancellationToken = default);
1215

1316
Task<ImageResult> TextToImageFile(PromptOptions prompt, string outputFile);
1417
Task<ImageResult> TextToImageFile(PromptOptions prompt, SchedulerOptions options, string outputFile);
18+
Task<ImageResult> TextToImageFile(PromptOptions prompt, SchedulerOptions options, string outputFile, Action<int, int> progress = null, CancellationToken cancellationToken = default);
1519
}
1620
}

OnnxStack.StableDiffusion/Services/SchedulerService.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
using System;
1212
using System.Collections.Generic;
1313
using System.Linq;
14+
using System.Threading;
1415
using System.Threading.Tasks;
1516

1617

@@ -41,7 +42,7 @@ public SchedulerService(OnnxStackConfig configuration, IOnnxModelService onnxMod
4142
/// <param name="promptOptions">The options.</param>
4243
/// <param name="schedulerOptions">The scheduler configuration.</param>
4344
/// <returns></returns>
44-
public async Task<DenseTensor<float>> RunAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions)
45+
public async Task<DenseTensor<float>> RunAsync(PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action<int, int> progress = null, CancellationToken cancellationToken = default)
4546
{
4647
// Create random seed if none was set
4748
schedulerOptions.Seed = schedulerOptions.Seed > 0 ? schedulerOptions.Seed : Random.Shared.Next();
@@ -63,6 +64,8 @@ public async Task<DenseTensor<float>> RunAsync(PromptOptions promptOptions, Sche
6364
var step = 0;
6465
foreach (var timestep in timesteps)
6566
{
67+
cancellationToken.ThrowIfCancellationRequested();
68+
6669
// Create input tensor.
6770
var inputTensor = scheduler.ScaleInput(latentSample.Duplicate(schedulerOptions.GetScaledDimension(2)), timestep);
6871

@@ -90,6 +93,7 @@ public async Task<DenseTensor<float>> RunAsync(PromptOptions promptOptions, Sche
9093
}
9194

9295
Console.WriteLine($"Step: {++step}/{timesteps.Count}");
96+
progress?.Invoke(step, timesteps.Count);
9397
}
9498

9599
// Decode Latents

OnnxStack.StableDiffusion/Services/StableDiffusionService.cs

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
using OnnxStack.StableDiffusion.Config;
33
using OnnxStack.StableDiffusion.Helpers;
44
using OnnxStack.StableDiffusion.Results;
5+
using System;
6+
using System.Threading;
57
using System.Threading.Tasks;
68

79
namespace OnnxStack.StableDiffusion.Services
@@ -25,6 +27,14 @@ public Task<ImageResult> TextToImage(PromptOptions prompt, SchedulerOptions opti
2527
return TextToImageInternal(prompt, options);
2628
}
2729

30+
public Task<ImageResult> TextToImage(PromptOptions prompt, SchedulerOptions options, Action<int, int> progress = null, CancellationToken cancellationToken = default)
31+
{
32+
return TextToImageInternal(prompt, options, progress, cancellationToken);
33+
}
34+
35+
36+
37+
2838
public Task<ImageResult> TextToImageFile(PromptOptions prompt, string outputFile)
2939
{
3040
return TextToImageFileInternal(prompt, new SchedulerOptions(), outputFile);
@@ -35,16 +45,21 @@ public Task<ImageResult> TextToImageFile(PromptOptions prompt, SchedulerOptions
3545
return TextToImageFileInternal(prompt, options, outputFile);
3646
}
3747

48+
public Task<ImageResult> TextToImageFile(PromptOptions prompt, SchedulerOptions options, string outputFile, Action<int, int> progress = null, CancellationToken cancellationToken = default)
49+
{
50+
return TextToImageFileInternal(prompt, options, outputFile, progress, cancellationToken);
51+
}
52+
3853

39-
private async Task<ImageResult> TextToImageInternal(PromptOptions prompt, SchedulerOptions options)
54+
private async Task<ImageResult> TextToImageInternal(PromptOptions prompt, SchedulerOptions options, Action<int, int> progress = null, CancellationToken cancellationToken = default)
4055
{
41-
var imageTensorData = await _schedulerService.RunAsync(prompt, options).ConfigureAwait(false);
56+
var imageTensorData = await _schedulerService.RunAsync(prompt, options, progress, cancellationToken).ConfigureAwait(false);
4257
return ImageHelpers.TensorToImage(options, imageTensorData);
4358
}
4459

45-
private async Task<ImageResult> TextToImageFileInternal(PromptOptions prompt, SchedulerOptions options, string outputFile)
60+
private async Task<ImageResult> TextToImageFileInternal(PromptOptions prompt, SchedulerOptions options, string outputFile, Action<int, int> progress = null, CancellationToken cancellationToken = default)
4661
{
47-
var result = await TextToImageInternal(prompt, options);
62+
var result = await TextToImageInternal(prompt, options, progress, cancellationToken);
4863
if (result is null)
4964
return null;
5065

0 commit comments

Comments
 (0)