Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit e4f9bca

Browse files
committed
Configurable scheduler sets
1 parent 106a4c4 commit e4f9bca

16 files changed

+179
-107
lines changed

OnnxStack.Core/Extensions/Extensions.cs

Lines changed: 1 addition & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ public static SessionOptions GetSessionOptions(this OnnxModelConfig configuratio
1919
InterOpNumThreads = configuration.InterOpNumThreads.Value,
2020
IntraOpNumThreads = configuration.IntraOpNumThreads.Value
2121
};
22+
2223
switch (configuration.ExecutionProvider)
2324
{
2425
case ExecutionProvider.DirectML:
@@ -87,72 +88,6 @@ public static bool IsNullOrEmpty<TSource>(this IEnumerable<TSource> source)
8788
}
8889

8990

90-
/// <summary>
91-
/// Batches the source sequence into sized buckets.
92-
/// </summary>
93-
/// <typeparam name="TSource">Type of elements in <paramref name="source" /> sequence.</typeparam>
94-
/// <param name="source">The source sequence.</param>
95-
/// <param name="size">Size of buckets.</param>
96-
/// <returns>A sequence of equally sized buckets containing elements of the source collection.</returns>
97-
/// <remarks>
98-
/// This operator uses deferred execution and streams its results (buckets and bucket content).
99-
/// </remarks>
100-
public static IEnumerable<IEnumerable<TSource>> Batch<TSource>(this IEnumerable<TSource> source, int size)
101-
{
102-
return Batch(source, size, x => x);
103-
}
104-
105-
/// <summary>
106-
/// Batches the source sequence into sized buckets and applies a projection to each bucket.
107-
/// </summary>
108-
/// <typeparam name="TSource">Type of elements in <paramref name="source" /> sequence.</typeparam>
109-
/// <typeparam name="TResult">Type of result returned by <paramref name="resultSelector" />.</typeparam>
110-
/// <param name="source">The source sequence.</param>
111-
/// <param name="size">Size of buckets.</param>
112-
/// <param name="resultSelector">The projection to apply to each bucket.</param>
113-
/// <returns>A sequence of projections on equally sized buckets containing elements of the source collection.</returns>
114-
/// <remarks>
115-
/// This operator uses deferred execution and streams its results (buckets and bucket content).
116-
/// </remarks>
117-
public static IEnumerable<TResult> Batch<TSource, TResult>(this IEnumerable<TSource> source, int size, Func<IEnumerable<TSource>, TResult> resultSelector)
118-
{
119-
if (source == null)
120-
throw new ArgumentNullException(nameof(source));
121-
if (size <= 0)
122-
throw new ArgumentOutOfRangeException(nameof(size));
123-
if (resultSelector == null)
124-
throw new ArgumentNullException(nameof(resultSelector));
125-
return BatchImpl(source, size, resultSelector);
126-
}
127-
128-
129-
private static IEnumerable<TResult> BatchImpl<TSource, TResult>(this IEnumerable<TSource> source, int size, Func<IEnumerable<TSource>, TResult> resultSelector)
130-
{
131-
TSource[] bucket = null;
132-
var count = 0;
133-
foreach (var item in source)
134-
{
135-
if (bucket == null)
136-
bucket = new TSource[size];
137-
138-
bucket[count++] = item;
139-
140-
// The bucket is fully buffered before it's yielded
141-
if (count != size)
142-
continue;
143-
144-
// Select is necessary so bucket contents are streamed too
145-
yield return resultSelector(bucket.Select(x => x));
146-
bucket = null;
147-
count = 0;
148-
}
149-
150-
// Return the last bucket with all remaining elements
151-
if (bucket != null && count > 0)
152-
yield return resultSelector(bucket.Take(count));
153-
}
154-
155-
15691
/// <summary>
15792
/// Get the index of the specified item
15893
/// </summary>

OnnxStack.StableDiffusion/Config/SchedulerOptions.cs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using OnnxStack.StableDiffusion.Enums;
22
using System.Collections.Generic;
33
using System.ComponentModel.DataAnnotations;
4+
using System.Text.Json.Serialization;
45

56
namespace OnnxStack.StableDiffusion.Config
67
{
@@ -36,6 +37,7 @@ public record SchedulerOptions
3637
/// If value is set to 0 a random seed is used.
3738
/// </value>
3839
[Range(0, int.MaxValue)]
40+
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
3941
public int Seed { get; set; }
4042

4143
/// <summary>
@@ -45,6 +47,7 @@ public record SchedulerOptions
4547
/// The number of steps to run inference for. The more steps the longer it will take to run the inference loop but the image quality should improve.
4648
/// </value>
4749
[Range(5, 200)]
50+
4851
public int InferenceSteps { get; set; } = 30;
4952

5053
/// <summary>
@@ -62,34 +65,76 @@ public record SchedulerOptions
6265
public float Strength { get; set; } = 0.6f;
6366

6467
[Range(0, int.MaxValue)]
68+
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
6569
public int TrainTimesteps { get; set; } = 1000;
70+
71+
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
6672
public float BetaStart { get; set; } = 0.00085f;
73+
74+
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
6775
public float BetaEnd { get; set; } = 0.012f;
76+
77+
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
6878
public IEnumerable<float> TrainedBetas { get; set; }
79+
80+
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
6981
public TimestepSpacingType TimestepSpacing { get; set; } = TimestepSpacingType.Linspace;
82+
83+
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
7084
public BetaScheduleType BetaSchedule { get; set; } = BetaScheduleType.ScaledLinear;
85+
86+
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
7187
public int StepsOffset { get; set; } = 0;
88+
89+
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
7290
public bool UseKarrasSigmas { get; set; } = false;
91+
92+
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
7393
public VarianceType VarianceType { get; set; } = VarianceType.FixedSmall;
94+
95+
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
7496
public float SampleMaxValue { get; set; } = 1.0f;
97+
98+
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
7599
public bool Thresholding { get; set; } = false;
100+
101+
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
76102
public bool ClipSample { get; set; } = false;
103+
104+
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
77105
public float ClipSampleRange { get; set; } = 1f;
106+
107+
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
78108
public PredictionType PredictionType { get; set; } = PredictionType.Epsilon;
109+
110+
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
79111
public AlphaTransformType AlphaTransformType { get; set; } = AlphaTransformType.Cosine;
112+
113+
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
80114
public float MaximumBeta { get; set; } = 0.999f;
115+
116+
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
81117
public List<int> Timesteps { get; set; }
82118

119+
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
83120
public int OriginalInferenceSteps { get; set; } = 50;
84121

122+
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
85123
public float AestheticScore { get; set; } = 6f;
124+
125+
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
86126
public float AestheticNegativeScore { get; set; } = 2.5f;
87127

128+
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
88129
public float ConditioningScale { get; set; } = 0.7f;
89130

131+
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
90132
public int InferenceSteps2 { get; set; } = 10;
133+
134+
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
91135
public float GuidanceScale2 { get; set; } = 0;
92136

137+
[JsonIgnore]
93138
public bool IsKarrasScheduler
94139
{
95140
get

OnnxStack.StableDiffusion/Config/StableDiffusionModelSet.cs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,19 @@ public record StableDiffusionModelSet : IOnnxModelSetConfig
1111
{
1212
public string Name { get; set; }
1313
public bool IsEnabled { get; set; }
14-
public int SampleSize { get; set; } = 512;
15-
public DiffuserPipelineType PipelineType { get; set; }
16-
public List<DiffuserType> Diffusers { get; set; } = new List<DiffuserType>();
17-
public MemoryModeType MemoryMode { get; set; }
1814
public int DeviceId { get; set; }
1915
public int InterOpNumThreads { get; set; }
2016
public int IntraOpNumThreads { get; set; }
2117
public ExecutionMode ExecutionMode { get; set; }
2218
public ExecutionProvider ExecutionProvider { get; set; }
2319
public OnnxModelPrecision Precision { get; set; }
20+
public MemoryModeType MemoryMode { get; set; }
21+
public int SampleSize { get; set; } = 512;
22+
public DiffuserPipelineType PipelineType { get; set; }
23+
public List<DiffuserType> Diffusers { get; set; }
24+
25+
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
26+
public List<SchedulerType> Schedulers { get; set; }
2427

2528
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
2629
public TokenizerModelConfig TokenizerConfig { get; set; }

OnnxStack.StableDiffusion/Diffusers/StableCascade/StableCascadeDiffuser.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ protected async Task<DenseTensor<float>> DiffuseDecodeAsync(PromptOptions prompt
252252

253253
// Unload if required
254254
if (_memoryMode == MemoryModeType.Minimum)
255-
await _unet.UnloadAsync();
255+
await _decoderUnet.UnloadAsync();
256256

257257
return latents;
258258
}

OnnxStack.StableDiffusion/Models/AutoEncoderModel.cs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using OnnxStack.Core.Config;
1+
using Microsoft.ML.OnnxRuntime;
2+
using OnnxStack.Core.Config;
23
using OnnxStack.Core.Model;
34

45
namespace OnnxStack.StableDiffusion.Models
@@ -13,6 +14,27 @@ public AutoEncoderModel(AutoEncoderModelConfig configuration) : base(configurati
1314
}
1415

1516
public float ScaleFactor => _configuration.ScaleFactor;
17+
18+
19+
public static AutoEncoderModel Create(AutoEncoderModelConfig configuration)
20+
{
21+
return new AutoEncoderModel(configuration);
22+
}
23+
24+
public static AutoEncoderModel Create(string modelFile, float scaleFactor, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML)
25+
{
26+
var configuration = new AutoEncoderModelConfig
27+
{
28+
DeviceId = deviceId,
29+
ExecutionProvider = executionProvider,
30+
ExecutionMode = ExecutionMode.ORT_SEQUENTIAL,
31+
InterOpNumThreads = 0,
32+
IntraOpNumThreads = 0,
33+
OnnxModelPath = modelFile,
34+
ScaleFactor = scaleFactor
35+
};
36+
return new AutoEncoderModel(configuration);
37+
}
1638
}
1739

1840
public record AutoEncoderModelConfig : OnnxModelConfig

OnnxStack.StableDiffusion/Models/ControlNetModel.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ public static ControlNetModel Create(string modelFile, ControlNetType type, int
3232
InterOpNumThreads = 0,
3333
IntraOpNumThreads = 0,
3434
OnnxModelPath = modelFile,
35-
Type = type,
35+
Type = type
3636
};
3737
return new ControlNetModel(configuration);
3838
}

OnnxStack.StableDiffusion/Models/TextEncoderModel.cs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using OnnxStack.Core.Config;
1+
using Microsoft.ML.OnnxRuntime;
2+
using OnnxStack.Core.Config;
23
using OnnxStack.Core.Model;
34

45
namespace OnnxStack.StableDiffusion.Models
@@ -11,6 +12,25 @@ public TextEncoderModel(TextEncoderModelConfig configuration) : base(configurati
1112
{
1213
_configuration = configuration;
1314
}
15+
16+
public static TextEncoderModel Create(TextEncoderModelConfig configuration)
17+
{
18+
return new TextEncoderModel(configuration);
19+
}
20+
21+
public static TextEncoderModel Create(string modelFile, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML)
22+
{
23+
var configuration = new TextEncoderModelConfig
24+
{
25+
DeviceId = deviceId,
26+
ExecutionProvider = executionProvider,
27+
ExecutionMode = ExecutionMode.ORT_SEQUENTIAL,
28+
InterOpNumThreads = 0,
29+
IntraOpNumThreads = 0,
30+
OnnxModelPath = modelFile
31+
};
32+
return new TextEncoderModel(configuration);
33+
}
1434
}
1535

1636
public record TextEncoderModelConfig : OnnxModelConfig

OnnxStack.StableDiffusion/Models/TokenizerModel.cs

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
using OnnxStack.Core.Config;
1+
using Microsoft.ML.OnnxRuntime;
2+
using OnnxStack.Core.Config;
23
using OnnxStack.Core.Model;
4+
using OnnxStack.StableDiffusion.Enums;
35

46
namespace OnnxStack.StableDiffusion.Models
57
{
@@ -16,6 +18,29 @@ public TokenizerModel(TokenizerModelConfig configuration) : base(configuration)
1618
public int TokenizerLength => _configuration.TokenizerLength;
1719
public int PadTokenId => _configuration.PadTokenId;
1820
public int BlankTokenId => _configuration.BlankTokenId;
21+
22+
public static TokenizerModel Create(TokenizerModelConfig configuration)
23+
{
24+
return new TokenizerModel(configuration);
25+
}
26+
27+
public static TokenizerModel Create(string modelFile, int tokenizerLength = 768, int tokenizerLimit = 77, int padTokenId = 49407, int blankTokenId = 49407, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML)
28+
{
29+
var configuration = new TokenizerModelConfig
30+
{
31+
DeviceId = deviceId,
32+
ExecutionProvider = executionProvider,
33+
ExecutionMode = ExecutionMode.ORT_SEQUENTIAL,
34+
InterOpNumThreads = 0,
35+
IntraOpNumThreads = 0,
36+
OnnxModelPath = modelFile,
37+
PadTokenId = padTokenId,
38+
BlankTokenId = blankTokenId,
39+
TokenizerLength = tokenizerLength,
40+
TokenizerLimit = tokenizerLimit
41+
};
42+
return new TokenizerModel(configuration);
43+
}
1944
}
2045

2146
public record TokenizerModelConfig : OnnxModelConfig

OnnxStack.StableDiffusion/Models/UNetConditionModel.cs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using OnnxStack.Core.Config;
1+
using Microsoft.ML.OnnxRuntime;
2+
using OnnxStack.Core.Config;
23
using OnnxStack.Core.Model;
34
using OnnxStack.StableDiffusion.Enums;
45

@@ -14,6 +15,26 @@ public UNetConditionModel(UNetConditionModelConfig configuration) : base(configu
1415
}
1516

1617
public ModelType ModelType => _configuration.ModelType;
18+
19+
public static UNetConditionModel Create(UNetConditionModelConfig configuration)
20+
{
21+
return new UNetConditionModel(configuration);
22+
}
23+
24+
public static UNetConditionModel Create(string modelFile, ModelType modelType, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML)
25+
{
26+
var configuration = new UNetConditionModelConfig
27+
{
28+
DeviceId = deviceId,
29+
ExecutionProvider = executionProvider,
30+
ExecutionMode = ExecutionMode.ORT_SEQUENTIAL,
31+
InterOpNumThreads = 0,
32+
IntraOpNumThreads = 0,
33+
OnnxModelPath = modelFile,
34+
ModelType = modelType
35+
};
36+
return new UNetConditionModel(configuration);
37+
}
1738
}
1839

1940

OnnxStack.StableDiffusion/Pipelines/InstaFlowPipeline.cs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Microsoft.Extensions.Logging;
22
using OnnxStack.Core;
33
using OnnxStack.Core.Config;
4+
using OnnxStack.StableDiffusion.Common;
45
using OnnxStack.StableDiffusion.Config;
56
using OnnxStack.StableDiffusion.Diffusers;
67
using OnnxStack.StableDiffusion.Diffusers.InstaFlow;
@@ -25,14 +26,14 @@ public sealed class InstaFlowPipeline : StableDiffusionPipeline
2526
/// <param name="vaeDecoder">The vae decoder.</param>
2627
/// <param name="vaeEncoder">The vae encoder.</param>
2728
/// <param name="logger">The logger.</param>
28-
public InstaFlowPipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TextEncoderModel textEncoder, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, UNetConditionModel controlNet, List<DiffuserType> diffusers, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default)
29-
: base(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlNet, diffusers, defaultSchedulerOptions, logger)
29+
public InstaFlowPipeline(PipelineOptions pipelineOptions, TokenizerModel tokenizer, TextEncoderModel textEncoder, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, UNetConditionModel controlNet, List<DiffuserType> diffusers, List<SchedulerType> schedulers, SchedulerOptions defaultSchedulerOptions = default, ILogger logger = default)
30+
: base(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlNet, diffusers, schedulers, defaultSchedulerOptions, logger)
3031
{
3132
_supportedDiffusers = diffusers ?? new List<DiffuserType>
3233
{
3334
DiffuserType.TextToImage
3435
};
35-
_supportedSchedulers = new List<SchedulerType>
36+
_supportedSchedulers = schedulers ?? new List<SchedulerType>
3637
{
3738
SchedulerType.InstaFlow
3839
};
@@ -87,7 +88,7 @@ protected override IDiffuser CreateDiffuser(DiffuserType diffuserType, ControlNe
8788
controlnet = new UNetConditionModel(config.ControlNetUnetConfig.ApplyDefaults(config));
8889

8990
var pipelineOptions = new PipelineOptions(config.Name, config.MemoryMode);
90-
return new InstaFlowPipeline(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlnet, config.Diffusers, config.SchedulerOptions, logger);
91+
return new InstaFlowPipeline(pipelineOptions, tokenizer, textEncoder, unet, vaeDecoder, vaeEncoder, controlnet, config.Diffusers, config.Schedulers, config.SchedulerOptions, logger);
9192
}
9293

9394

0 commit comments

Comments (0)