Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit bc09f68

Browse files
committed
ModelOptions wrapper for passing multiple models to Diffusers
1 parent 92cb75c commit bc09f68

37 files changed

+156
-124
lines changed

OnnxStack.Console/Examples/StableDebug.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ private async Task<bool> GenerateImage(StableDiffusionModelSet model, PromptOpti
7777
{
7878
var timestamp = Stopwatch.GetTimestamp();
7979
var outputFilename = Path.Combine(_outputDirectory, $"{model.Name}_{options.Seed}_{options.SchedulerType}.png");
80-
var result = await _stableDiffusionService.GenerateAsImageAsync(model, prompt, options);
80+
var result = await _stableDiffusionService.GenerateAsImageAsync(new ModelOptions(model), prompt, options);
8181
if (result is not null)
8282
{
8383
await result.SaveAsPngAsync(outputFilename);

OnnxStack.Console/Examples/StableDiffusionBatch.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ public async Task RunAsync()
6666
OutputHelpers.WriteConsole($"Image: {progress.BatchValue}/{progress.BatchMax} - Step: {progress.StepValue}/{progress.StepMax}", ConsoleColor.Cyan);
6767
};
6868

69-
await foreach (var result in _stableDiffusionService.GenerateBatchAsync(model, promptOptions, schedulerOptions, batchOptions, default))
69+
await foreach (var result in _stableDiffusionService.GenerateBatchAsync(new ModelOptions(model), promptOptions, schedulerOptions, batchOptions, default))
7070
{
7171
var outputFilename = Path.Combine(_outputDirectory, $"{batchIndex}_{result.SchedulerOptions.Seed}.png");
7272
var image = result.ImageResult.ToImage();

OnnxStack.Console/Examples/StableDiffusionExample.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ public async Task RunAsync()
7070
private async Task<bool> GenerateImage(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options)
7171
{
7272
var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{options.SchedulerType}.png");
73-
var result = await _stableDiffusionService.GenerateAsImageAsync(model, prompt, options);
73+
var result = await _stableDiffusionService.GenerateAsImageAsync(new ModelOptions(model), prompt, options);
7474
if (result == null)
7575
return false;
7676

OnnxStack.Console/Examples/StableDiffusionGenerator.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ public async Task RunAsync()
6767
private async Task<bool> GenerateImage(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, string key)
6868
{
6969
var outputFilename = Path.Combine(_outputDirectory, $"{options.Seed}_{options.SchedulerType}_{key}.png");
70-
var result = await _stableDiffusionService.GenerateAsImageAsync(model, prompt, options);
70+
var result = await _stableDiffusionService.GenerateAsImageAsync(new ModelOptions(model), prompt, options);
7171
if (result == null)
7272
return false;
7373

OnnxStack.Console/Examples/StableDiffusionGif.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ public async Task RunAsync()
102102

103103
// Set prompt Image, Run Diffusion
104104
promptOptions.InputImage = new InputImage(mergedFrame.CloneAs<Rgba32>());
105-
var result = await _stableDiffusionService.GenerateAsImageAsync(model, promptOptions, schedulerOptions);
105+
var result = await _stableDiffusionService.GenerateAsImageAsync(new ModelOptions(model), promptOptions, schedulerOptions);
106106

107107
// Save Debug Output
108108
await result.SaveAsPngAsync(Path.Combine(_outputDirectory, $"Debug-Output.png"));

OnnxStack.StableDiffusion/Common/IStableDiffusionService.cs

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using Microsoft.ML.OnnxRuntime.Tensors;
2+
using OnnxStack.Core.Config;
23
using OnnxStack.StableDiffusion.Config;
34
using OnnxStack.StableDiffusion.Models;
45
using SixLabors.ImageSharp;
@@ -18,15 +19,15 @@ public interface IStableDiffusionService
1819
/// </summary>
1920
/// <param name="modelOptions">The model options.</param>
2021
/// <returns></returns>
21-
Task<bool> LoadModelAsync(StableDiffusionModelSet model);
22+
Task<bool> LoadModelAsync(IOnnxModelSetConfig model);
2223

2324

2425
/// <summary>
2526
/// Unloads the model.
2627
/// </summary>
2728
/// <param name="modelOptions">The model options.</param>
2829
/// <returns></returns>
29-
Task<bool> UnloadModelAsync(StableDiffusionModelSet model);
30+
Task<bool> UnloadModelAsync(IOnnxModel model);
3031

3132
/// <summary>
3233
/// Determines whether the specified model is loaded
@@ -35,7 +36,7 @@ public interface IStableDiffusionService
3536
/// <returns>
3637
/// <c>true</c> if the specified model is loaded; otherwise, <c>false</c>.
3738
/// </returns>
38-
bool IsModelLoaded(StableDiffusionModelSet model);
39+
bool IsModelLoaded(IOnnxModel model);
3940

4041
/// <summary>
4142
/// Generates the StableDiffusion image using the prompt and options provided.
@@ -45,7 +46,7 @@ public interface IStableDiffusionService
4546
/// <param name="progressCallback">The callback used to provide progress of the current InferenceSteps.</param>
4647
/// <param name="cancellationToken">The cancellation token.</param>
4748
/// <returns>The diffusion result as <see cref="DenseTensor<float>"/></returns>
48-
Task<DenseTensor<float>> GenerateAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
49+
Task<DenseTensor<float>> GenerateAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
4950

5051
/// <summary>
5152
/// Generates the StableDiffusion image using the prompt and options provided.
@@ -55,7 +56,7 @@ public interface IStableDiffusionService
5556
/// <param name="progressCallback">The callback used to provide progress of the current InferenceSteps.</param>
5657
/// <param name="cancellationToken">The cancellation token.</param>
5758
/// <returns>The diffusion result as <see cref="SixLabors.ImageSharp.Image<Rgba32>"/></returns>
58-
Task<Image<Rgba32>> GenerateAsImageAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
59+
Task<Image<Rgba32>> GenerateAsImageAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
5960

6061
/// <summary>
6162
/// Generates the StableDiffusion image using the prompt and options provided.
@@ -65,7 +66,7 @@ public interface IStableDiffusionService
6566
/// <param name="progressCallback">The callback used to provide progress of the current InferenceSteps.</param>
6667
/// <param name="cancellationToken">The cancellation token.</param>
6768
/// <returns>The diffusion result as <see cref="byte[]"/></returns>
68-
Task<byte[]> GenerateAsBytesAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
69+
Task<byte[]> GenerateAsBytesAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
6970

7071
/// <summary>
7172
/// Generates the StableDiffusion image using the prompt and options provided.
@@ -75,7 +76,7 @@ public interface IStableDiffusionService
7576
/// <param name="progressCallback">The callback used to provide progress of the current InferenceSteps.</param>
7677
/// <param name="cancellationToken">The cancellation token.</param>
7778
/// <returns>The diffusion result as <see cref="System.IO.Stream"/></returns>
78-
Task<Stream> GenerateAsStreamAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
79+
Task<Stream> GenerateAsStreamAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
7980

8081
/// <summary>
8182
/// Generates a batch of StableDiffusion image using the prompt and options provided.
@@ -87,7 +88,7 @@ public interface IStableDiffusionService
8788
/// <param name="progressCallback">The progress callback.</param>
8889
/// <param name="cancellationToken">The cancellation token.</param>
8990
/// <returns></returns>
90-
IAsyncEnumerable<BatchResult> GenerateBatchAsync(StableDiffusionModelSet model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
91+
IAsyncEnumerable<BatchResult> GenerateBatchAsync(ModelOptions model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
9192

9293
/// <summary>
9394
/// Generates a batch of StableDiffusion image using the prompt and options provided.
@@ -99,7 +100,7 @@ public interface IStableDiffusionService
99100
/// <param name="progressCallback">The progress callback.</param>
100101
/// <param name="cancellationToken">The cancellation token.</param>
101102
/// <returns></returns>
102-
IAsyncEnumerable<Image<Rgba32>> GenerateBatchAsImageAsync(StableDiffusionModelSet model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
103+
IAsyncEnumerable<Image<Rgba32>> GenerateBatchAsImageAsync(ModelOptions model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
103104

104105
/// <summary>
105106
/// Generates a batch of StableDiffusion image using the prompt and options provided.
@@ -111,7 +112,7 @@ public interface IStableDiffusionService
111112
/// <param name="progressCallback">The progress callback.</param>
112113
/// <param name="cancellationToken">The cancellation token.</param>
113114
/// <returns></returns>
114-
IAsyncEnumerable<byte[]> GenerateBatchAsBytesAsync(StableDiffusionModelSet model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
115+
IAsyncEnumerable<byte[]> GenerateBatchAsBytesAsync(ModelOptions model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
115116

116117
/// <summary>
117118
/// Generates a batch of StableDiffusion image using the prompt and options provided.
@@ -123,6 +124,6 @@ public interface IStableDiffusionService
123124
/// <param name="progressCallback">The progress callback.</param>
124125
/// <param name="cancellationToken">The cancellation token.</param>
125126
/// <returns></returns>
126-
IAsyncEnumerable<Stream> GenerateBatchAsStreamAsync(StableDiffusionModelSet model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
127+
IAsyncEnumerable<Stream> GenerateBatchAsStreamAsync(ModelOptions model, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
127128
}
128129
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
using Microsoft.ML.OnnxRuntime;
2+
using OnnxStack.Core.Config;
3+
using System.Collections.Generic;
4+
5+
namespace OnnxStack.StableDiffusion.Config
6+
{
7+
public record ControlNetModelSet : IOnnxModelSetConfig
8+
{
9+
public string Name { get; set; }
10+
public bool IsEnabled { get; set; }
11+
public int DeviceId { get; set; }
12+
public int InterOpNumThreads { get; set; }
13+
public int IntraOpNumThreads { get; set; }
14+
public ExecutionMode ExecutionMode { get; set; }
15+
public ExecutionProvider ExecutionProvider { get; set; }
16+
public List<OnnxModelConfig> ModelConfigurations { get; set; }
17+
}
18+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
using OnnxStack.StableDiffusion.Enums;
2+
3+
namespace OnnxStack.StableDiffusion.Config
4+
{
5+
public record ModelOptions(StableDiffusionModelSet BaseModel, ControlNetModelSet ControlNetModel = default)
6+
{
7+
public string Name => BaseModel.Name;
8+
public DiffuserPipelineType PipelineType => BaseModel.PipelineType;
9+
public float ScaleFactor => BaseModel.ScaleFactor;
10+
public ModelType ModelType => BaseModel.ModelType;
11+
}
12+
}

OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ public DiffuserBase(IOnnxModelService onnxModelService, IPromptService promptSer
7575
/// <param name="scheduler">The scheduler.</param>
7676
/// <param name="timesteps">The timesteps.</param>
7777
/// <returns></returns>
78-
protected abstract Task<DenseTensor<float>> PrepareLatentsAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps);
78+
protected abstract Task<DenseTensor<float>> PrepareLatentsAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps);
7979

8080

8181
/// <summary>
@@ -89,7 +89,7 @@ public DiffuserBase(IOnnxModelService onnxModelService, IPromptService promptSer
8989
/// <param name="progressCallback">The progress callback.</param>
9090
/// <param name="cancellationToken">The cancellation token.</param>
9191
/// <returns></returns>
92-
protected abstract Task<DenseTensor<float>> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
92+
protected abstract Task<DenseTensor<float>> SchedulerStepAsync(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default);
9393

9494

9595
/// <summary>
@@ -100,7 +100,7 @@ public DiffuserBase(IOnnxModelService onnxModelService, IPromptService promptSer
100100
/// <param name="progress">The progress.</param>
101101
/// <param name="cancellationToken">The cancellation token.</param>
102102
/// <returns></returns>
103-
public virtual async Task<DenseTensor<float>> DiffuseAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
103+
public virtual async Task<DenseTensor<float>> DiffuseAsync(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
104104
{
105105
// Create random seed if none was set
106106
schedulerOptions.Seed = schedulerOptions.Seed > 0 ? schedulerOptions.Seed : Random.Shared.Next();
@@ -112,7 +112,7 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(StableDiffusionModelS
112112
var performGuidance = ShouldPerformGuidance(schedulerOptions);
113113

114114
// Process prompts
115-
var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, performGuidance);
115+
var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions.BaseModel, promptOptions, performGuidance);
116116

117117
// If video input, process frames
118118
if (promptOptions.HasInputVideo)
@@ -157,7 +157,7 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(StableDiffusionModelS
157157
/// <param name="cancellationToken">The cancellation token.</param>
158158
/// <returns></returns>
159159
/// <exception cref="System.NotImplementedException"></exception>
160-
public virtual async IAsyncEnumerable<BatchResult> DiffuseBatchAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<DiffusionProgress> progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
160+
public virtual async IAsyncEnumerable<BatchResult> DiffuseBatchAsync(ModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, BatchOptions batchOptions, Action<DiffusionProgress> progressCallback = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
161161
{
162162
// Create random seed if none was set
163163
schedulerOptions.Seed = schedulerOptions.Seed > 0 ? schedulerOptions.Seed : Random.Shared.Next();
@@ -170,7 +170,7 @@ public virtual async IAsyncEnumerable<BatchResult> DiffuseBatchAsync(StableDiffu
170170
var performGuidance = ShouldPerformGuidance(schedulerOptions);
171171

172172
// Process prompts
173-
var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, performGuidance);
173+
var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions.BaseModel, promptOptions, performGuidance);
174174

175175
// Generate batch options
176176
var batchSchedulerOptions = BatchGenerator.GenerateBatch(modelOptions, batchOptions, schedulerOptions);
@@ -231,21 +231,21 @@ protected virtual DenseTensor<float> PerformGuidance(DenseTensor<float> noisePre
231231
/// <param name="options">The options.</param>
232232
/// <param name="latents">The latents.</param>
233233
/// <returns></returns>
234-
protected virtual async Task<DenseTensor<float>> DecodeLatentsAsync(StableDiffusionModelSet model, PromptOptions prompt, SchedulerOptions options, DenseTensor<float> latents)
234+
protected virtual async Task<DenseTensor<float>> DecodeLatentsAsync(ModelOptions model, PromptOptions prompt, SchedulerOptions options, DenseTensor<float> latents)
235235
{
236236
var timestamp = _logger.LogBegin();
237237

238238
// Scale and decode the image latents with vae.
239239
latents = latents.MultiplyBy(1.0f / model.ScaleFactor);
240240

241241
var outputDim = new[] { 1, 3, options.Height, options.Width };
242-
var metadata = _onnxModelService.GetModelMetadata(model, OnnxModelType.VaeDecoder);
242+
var metadata = _onnxModelService.GetModelMetadata(model.BaseModel, OnnxModelType.VaeDecoder);
243243
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
244244
{
245245
inferenceParameters.AddInputTensor(latents);
246246
inferenceParameters.AddOutputBuffer(outputDim);
247247

248-
var results = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeDecoder, inferenceParameters);
248+
var results = await _onnxModelService.RunInferenceAsync(model.BaseModel, OnnxModelType.VaeDecoder, inferenceParameters);
249249
using (var imageResult = results.First())
250250
{
251251
_logger?.LogEnd("Latents decoded", timestamp);

0 commit comments

Comments
 (0)