Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 1ecf164

Browse files
committed
Scheduler batching
1 parent 5b5072e commit 1ecf164

File tree

10 files changed

+62
-64
lines changed

10 files changed

+62
-64
lines changed

OnnxStack.Console/Examples/StableDebug.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
using OnnxStack.StableDiffusion.Common;
1+
using OnnxStack.StableDiffusion;
2+
using OnnxStack.StableDiffusion.Common;
23
using OnnxStack.StableDiffusion.Config;
34
using OnnxStack.StableDiffusion.Enums;
4-
using OnnxStack.StableDiffusion.Services;
55
using SixLabors.ImageSharp;
66
using System.Diagnostics;
77

@@ -54,7 +54,7 @@ public async Task RunAsync()
5454
OutputHelpers.WriteConsole($"Loading Model `{model.Name}`...", ConsoleColor.Green);
5555
await _stableDiffusionService.LoadModel(model);
5656

57-
foreach (var schedulerType in Helpers.GetPipelineSchedulers(model.PipelineType))
57+
foreach (var schedulerType in model.PipelineType.GetSchedulerTypes())
5858
{
5959
schedulerOptions.SchedulerType = schedulerType;
6060
OutputHelpers.WriteConsole($"Generating {schedulerType} Image...", ConsoleColor.Green);

OnnxStack.Console/Examples/StableDiffusionBatch.cs

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
using OnnxStack.StableDiffusion.Common;
22
using OnnxStack.StableDiffusion.Config;
33
using OnnxStack.StableDiffusion.Enums;
4-
using OnnxStack.StableDiffusion.Helpers;
4+
using OnnxStack.StableDiffusion;
55
using SixLabors.ImageSharp;
6+
using OnnxStack.StableDiffusion.Helpers;
67

78
namespace OnnxStack.Console.Runner
89
{
@@ -47,10 +48,7 @@ public async Task RunAsync()
4748

4849
var batchOptions = new BatchOptions
4950
{
50-
BatchType = BatchOptionType.Guidance,
51-
ValueFrom = 4,
52-
ValueTo = 20,
53-
Increment = 0.5f
51+
BatchType = BatchOptionType.Scheduler
5452
};
5553

5654
foreach (var model in _stableDiffusionService.Models)
@@ -65,16 +63,12 @@ public async Task RunAsync()
6563
OutputHelpers.WriteConsole($"Image: {batch}/{batchCount} - Step: {step}/{steps}", ConsoleColor.Cyan);
6664
};
6765

68-
foreach (var schedulerType in Helpers.GetPipelineSchedulers(model.PipelineType).Take(1))
66+
await foreach (var result in _stableDiffusionService.GenerateBatchAsync(model, promptOptions, schedulerOptions, batchOptions, callback))
6967
{
70-
schedulerOptions.SchedulerType = schedulerType;
71-
await foreach (var result in _stableDiffusionService.GenerateBatchAsync(model, promptOptions, schedulerOptions, batchOptions, callback))
72-
{
73-
var outputFilename = Path.Combine(_outputDirectory, $"{batchIndex}_{result.SchedulerOptions.Seed}.png");
74-
var image = result.ImageResult.ToImage();
75-
await image.SaveAsPngAsync(outputFilename);
76-
OutputHelpers.WriteConsole($"Image Created: {Path.GetFileName(outputFilename)}", ConsoleColor.Green);
77-
}
68+
var outputFilename = Path.Combine(_outputDirectory, $"{batchIndex}_{result.SchedulerOptions.Seed}.png");
69+
var image = result.ImageResult.ToImage();
70+
await image.SaveAsPngAsync(outputFilename);
71+
OutputHelpers.WriteConsole($"Image Created: {Path.GetFileName(outputFilename)}", ConsoleColor.Green);
7872
}
7973

8074
OutputHelpers.WriteConsole($"Unloading Model `{model.Name}`...", ConsoleColor.Green);

OnnxStack.Console/Examples/StableDiffusionExample.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
using OnnxStack.Core;
1+
using OnnxStack.StableDiffusion;
22
using OnnxStack.StableDiffusion.Common;
33
using OnnxStack.StableDiffusion.Config;
4-
using OnnxStack.StableDiffusion.Enums;
54
using SixLabors.ImageSharp;
65

76
namespace OnnxStack.Console.Runner
@@ -53,7 +52,7 @@ public async Task RunAsync()
5352
OutputHelpers.WriteConsole($"Loading Model `{model.Name}`...", ConsoleColor.Green);
5453
await _stableDiffusionService.LoadModel(model);
5554

56-
foreach (var schedulerType in Helpers.GetPipelineSchedulers(model.PipelineType))
55+
foreach (var schedulerType in model.PipelineType.GetSchedulerTypes())
5756
{
5857
schedulerOptions.SchedulerType = schedulerType;
5958
OutputHelpers.WriteConsole($"Generating {schedulerType} Image...", ConsoleColor.Green);

OnnxStack.Console/Examples/StableDiffusionGenerator.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
using OnnxStack.Core;
1+
using OnnxStack.StableDiffusion;
22
using OnnxStack.StableDiffusion.Common;
33
using OnnxStack.StableDiffusion.Config;
4-
using OnnxStack.StableDiffusion.Enums;
54
using SixLabors.ImageSharp;
65
using System.Collections.ObjectModel;
76

@@ -48,7 +47,7 @@ public async Task RunAsync()
4847
{
4948
Seed = Random.Shared.Next()
5049
};
51-
foreach (var schedulerType in Helpers.GetPipelineSchedulers(model.PipelineType))
50+
foreach (var schedulerType in model.PipelineType.GetSchedulerTypes())
5251
{
5352
schedulerOptions.SchedulerType = schedulerType;
5453
OutputHelpers.WriteConsole($"Generating {schedulerType} Image...", ConsoleColor.Green);

OnnxStack.Console/Helpers.cs

Lines changed: 0 additions & 28 deletions
This file was deleted.

OnnxStack.StableDiffusion/Diffusers/LatentConsistency/LatentConsistencyDiffuser.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ public async IAsyncEnumerable<BatchResult> DiffuseBatchAsync(IModelOptions model
127127
var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, performGuidance);
128128

129129
// Generate batch options
130-
var batchSchedulerOptions = BatchGenerator.GenerateBatch(batchOptions, schedulerOptions);
130+
var batchSchedulerOptions = BatchGenerator.GenerateBatch(modelOptions, batchOptions, schedulerOptions);
131131

132132
var batchIndex = 1;
133133
var schedulerCallback = (int step, int steps) => progressCallback?.Invoke(batchIndex, batchSchedulerOptions.Count, step, steps);

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/StableDiffusionDiffuser.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ public async IAsyncEnumerable<BatchResult> DiffuseBatchAsync(IModelOptions model
124124
var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, performGuidance);
125125

126126
// Generate batch options
127-
var batchSchedulerOptions = BatchGenerator.GenerateBatch(batchOptions, schedulerOptions);
127+
var batchSchedulerOptions = BatchGenerator.GenerateBatch(modelOptions, batchOptions, schedulerOptions);
128128

129129
var batchIndex = 1;
130130
var schedulerCallback = (int step, int steps) => progressCallback?.Invoke(batchIndex, batchSchedulerOptions.Count, step, steps);

OnnxStack.StableDiffusion/Enums/BatchOptionType.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ public enum BatchOptionType
55
Seed = 0,
66
Step = 1,
77
Guidance = 2,
8-
Strength = 3
8+
Strength = 3,
9+
Scheduler = 4
910
}
1011
}

OnnxStack.StableDiffusion/Extensions.cs

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,20 @@
11
using Microsoft.ML.OnnxRuntime;
2-
using OnnxStack.Core;
32
using OnnxStack.StableDiffusion.Config;
3+
using OnnxStack.StableDiffusion.Enums;
44
using System;
5-
using System.Collections.Generic;
65
using System.Linq;
76

87
namespace OnnxStack.StableDiffusion
98
{
10-
internal static class Extensions
9+
public static class Extensions
1110
{
1211
/// <summary>
1312
/// Gets the first element and casts it to the specified type.
1413
/// </summary>
1514
/// <typeparam name="T">Desired return type</typeparam>
1615
/// <param name="collection">The collection.</param>
1716
/// <returns>Firts element in the collection cast as <see cref="T"/></returns>
18-
public static T FirstElementAs<T>(this IDisposableReadOnlyCollection<DisposableNamedOnnxValue> collection)
17+
internal static T FirstElementAs<T>(this IDisposableReadOnlyCollection<DisposableNamedOnnxValue> collection)
1918
{
2019
if (collection is null || collection.Count == 0)
2120
return default;
@@ -34,7 +33,7 @@ public static T FirstElementAs<T>(this IDisposableReadOnlyCollection<DisposableN
3433
/// <typeparam name="T">Desired return type</typeparam>
3534
/// <param name="collection">The collection.</param>
3635
/// <returns>Last element in the collection cast as <see cref="T"/></returns>
37-
public static T LastElementAs<T>(this IDisposableReadOnlyCollection<DisposableNamedOnnxValue> collection)
36+
internal static T LastElementAs<T>(this IDisposableReadOnlyCollection<DisposableNamedOnnxValue> collection)
3837
{
3938
if (collection is null || collection.Count == 0)
4039
return default;
@@ -53,7 +52,7 @@ public static T LastElementAs<T>(this IDisposableReadOnlyCollection<DisposableNa
5352
/// <param name="options">The options.</param>
5453
/// <returns></returns>
5554
/// <exception cref="System.ArgumentOutOfRangeException">Width must be divisible by 64</exception>
56-
public static int GetScaledWidth(this SchedulerOptions options)
55+
internal static int GetScaledWidth(this SchedulerOptions options)
5756
{
5857
if (options.Width % 64 > 0)
5958
throw new ArgumentOutOfRangeException(nameof(options.Width), $"{nameof(options.Width)} must be divisible by 64");
@@ -68,7 +67,7 @@ public static int GetScaledWidth(this SchedulerOptions options)
6867
/// <param name="options">The options.</param>
6968
/// <returns></returns>
7069
/// <exception cref="System.ArgumentOutOfRangeException">Height must be divisible by 64</exception>
71-
public static int GetScaledHeight(this SchedulerOptions options)
70+
internal static int GetScaledHeight(this SchedulerOptions options)
7271
{
7372
if (options.Height % 64 > 0)
7473
throw new ArgumentOutOfRangeException(nameof(options.Height), $"{nameof(options.Height)} must be divisible by 64");
@@ -84,9 +83,36 @@ public static int GetScaledHeight(this SchedulerOptions options)
8483
/// <param name="batch">The batch.</param>
8584
/// <param name="channels">The channels.</param>
8685
/// <returns>Tensor dimension of [batch, channels, (Height / 8), (Width / 8)]</returns>
87-
public static int[] GetScaledDimension(this SchedulerOptions options, int batch = 1, int channels = 4)
86+
internal static int[] GetScaledDimension(this SchedulerOptions options, int batch = 1, int channels = 4)
8887
{
8988
return new[] { batch, channels, options.GetScaledHeight(), options.GetScaledWidth() };
9089
}
90+
91+
92+
/// <summary>
93+
/// Gets the pipeline schedulers.
94+
/// </summary>
95+
/// <param name="pipelineType">Type of the pipeline.</param>
96+
/// <returns></returns>
97+
public static SchedulerType[] GetSchedulerTypes(this DiffuserPipelineType pipelineType)
98+
{
99+
return pipelineType switch
100+
{
101+
DiffuserPipelineType.StableDiffusion => new[]
102+
{
103+
SchedulerType.LMS,
104+
SchedulerType.Euler,
105+
SchedulerType.EulerAncestral,
106+
SchedulerType.DDPM,
107+
SchedulerType.DDIM,
108+
SchedulerType.KDPM2
109+
},
110+
DiffuserPipelineType.LatentConsistency => new[]
111+
{
112+
SchedulerType.LCM
113+
},
114+
_ => default
115+
};
116+
}
91117
}
92118
}

OnnxStack.StableDiffusion/Helpers/BatchGenerator.cs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using OnnxStack.StableDiffusion.Config;
1+
using OnnxStack.StableDiffusion.Common;
2+
using OnnxStack.StableDiffusion.Config;
23
using OnnxStack.StableDiffusion.Enums;
34
using System;
45
using System.Collections.Generic;
@@ -14,7 +15,7 @@ public static class BatchGenerator
1415
/// <param name="batchOptions">The batch options.</param>
1516
/// <param name="schedulerOptions">The scheduler options.</param>
1617
/// <returns></returns>
17-
public static List<SchedulerOptions> GenerateBatch(BatchOptions batchOptions, SchedulerOptions schedulerOptions)
18+
public static List<SchedulerOptions> GenerateBatch(IModelOptions modelOptions, BatchOptions batchOptions, SchedulerOptions schedulerOptions)
1819
{
1920
if (batchOptions.BatchType == BatchOptionType.Seed)
2021
{
@@ -43,6 +44,12 @@ public static List<SchedulerOptions> GenerateBatch(BatchOptions batchOptions, Sc
4344
.Select(x => schedulerOptions with { Strength = batchOptions.ValueFrom + (batchOptions.Increment * x) })
4445
.ToList();
4546
}
47+
else if (batchOptions.BatchType == BatchOptionType.Scheduler)
48+
{
49+
return modelOptions.PipelineType.GetSchedulerTypes()
50+
.Select(x => schedulerOptions with { SchedulerType = x })
51+
.ToList();
52+
}
4653
return new List<SchedulerOptions>();
4754
}
4855
}

0 commit comments

Comments
 (0)