Skip to content

Commit 6403c90

Browse files
committed
Support multi-order schedulers
1 parent a6e62d1 commit 6403c90

File tree

17 files changed

+1160
-18
lines changed

17 files changed

+1160
-18
lines changed

TensorStack.Common/Extensions/TensorExtensions.cs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,23 @@ public static Tensor<float> SumTensors(this Tensor<float>[] tensors, ReadOnlySpa
311311
}
312312

313313

314+
/// <summary>
315+
/// Clips to the specified minimum/maximum value.
316+
/// </summary>
317+
/// <param name="tensor">The tensor.</param>
318+
/// <param name="minValue">The minimum value.</param>
319+
/// <param name="maxValue">The maximum value.</param>
320+
public static Tensor<float> ClipTo(this Tensor<float> tensor, float minValue, float maxValue)
321+
{
322+
var clipTensor = new Tensor<float>(tensor.Dimensions);
323+
for (int i = 0; i < tensor.Length; i++)
324+
{
325+
clipTensor.SetValue(i, Math.Clamp(tensor.Memory.Span[i], minValue, maxValue));
326+
}
327+
return clipTensor;
328+
}
329+
330+
314331
/// <summary>
315332
/// Reshapes to new tensor.
316333
/// </summary>
Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// Copyright (c) TensorStack. All rights reserved.
22
// Licensed under the Apache 2.0 License.
3+
using System;
34
using System.Diagnostics;
45
using TensorStack.Common.Pipeline;
56
using TensorStack.Common.Tensor;
@@ -12,21 +13,24 @@ public GenerateProgress() { }
1213
public GenerateProgress(string message)
1314
{
1415
Message = message;
16+
Type = ProgressType.Message;
1517
}
1618
public GenerateProgress(long elapsed)
1719
{
18-
Elapsed = Stopwatch.GetElapsedTime(elapsed).TotalMilliseconds;
20+
Elapsed = Stopwatch.GetElapsedTime(elapsed);
1921
}
22+
public ProgressType Type { get; set; }
2023
public string Message { get; set; }
24+
public TimeSpan Elapsed { get; set; }
2125

22-
public int BatchMax { get; set; }
23-
public int BatchValue { get; set; }
24-
public Tensor<float> BatchTensor { get; set; }
26+
public int Max { get; set; }
27+
public int Value { get; set; }
28+
public Tensor<float> Tensor { get; set; }
2529

26-
public int StepMax { get; set; }
27-
public int StepValue { get; set; }
28-
public Tensor<float> StepTensor { get; set; }
29-
30-
public double Elapsed { get; set; }
30+
public enum ProgressType
31+
{
32+
Message = 0,
33+
Step = 1
34+
}
3135
}
3236
}

TensorStack.StableDiffusion/Enums/SchedulerType.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,17 @@ public enum SchedulerType
1515
[Display(Name = "Euler Ancestral")]
1616
EulerAncestral = 2,
1717

18+
[Display(Name = "DDPM")]
19+
DDPM = 3,
1820

21+
[Display(Name = "DDIM")]
22+
DDIM = 4,
23+
24+
[Display(Name = "KDPM2")]
25+
KDPM2 = 5,
26+
27+
[Display(Name = "KDPM2-Ancestral")]
28+
KDPM2Ancestral = 6,
1929

2030
[Display(Name = "LCM")]
2131
LCM = 20

TensorStack.StableDiffusion/Extensions.cs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using Microsoft.Extensions.Logging;
44
using System;
55
using System.Diagnostics;
6+
using TensorStack.Common.Tensor;
67
using TensorStack.StableDiffusion.Common;
78

89
namespace TensorStack.StableDiffusion
@@ -20,6 +21,39 @@ public static void Notify(this IProgress<GenerateProgress> progressCallback, str
2021
}
2122

2223

24+
/// <summary>
25+
/// Notifies the specified message.
26+
/// </summary>
27+
/// <param name="progressCallback">The progress callback.</param>
28+
/// <param name="message">The message.</param>
29+
public static void Notify(this IProgress<GenerateProgress> progressCallback, GenerateProgress message)
30+
{
31+
progressCallback?.Report(message);
32+
}
33+
34+
35+
/// <summary>
36+
/// Notifies the specified step.
37+
/// </summary>
38+
/// <param name="progressCallback">The progress callback.</param>
39+
/// <param name="step">The step.</param>
40+
/// <param name="steps">The steps.</param>
41+
/// <param name="latents">The latents.</param>
42+
/// <param name="elapsed">The elapsed.</param>
43+
public static void Notify(this IProgress<GenerateProgress> progressCallback, int step, int steps, Tensor<float> latents, long elapsed)
44+
{
45+
progressCallback?.Report(new GenerateProgress
46+
{
47+
Max = steps,
48+
Value = step,
49+
Tensor = latents.Clone(),
50+
Type = GenerateProgress.ProgressType.Step,
51+
Message = $"Step: {step:D2}/{steps:D2}",
52+
Elapsed = elapsed > 0 ? Stopwatch.GetElapsedTime(elapsed) : TimeSpan.Zero
53+
});
54+
}
55+
56+
2357
/// <summary>
2458
/// Log and return timestamp.
2559
/// </summary>

TensorStack.StableDiffusion/Pipelines/PipelineBase.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,10 @@ protected virtual IScheduler CreateScheduler(GenerateOptions options)
106106
SchedulerType.LMS => new LMSScheduler(options),
107107
SchedulerType.Euler => new EulerScheduler(options),
108108
SchedulerType.EulerAncestral => new EulerAncestralScheduler(options),
109+
SchedulerType.DDPM => new DDPMScheduler(options),
110+
SchedulerType.DDIM => new DDIMScheduler(options),
111+
SchedulerType.KDPM2 => new KDPM2Scheduler(options),
112+
SchedulerType.KDPM2Ancestral => new KDPM2AncestralScheduler(options),
109113
SchedulerType.LCM => new LCMScheduler(options),
110114
_ => default
111115
};
@@ -140,5 +144,6 @@ protected Tensor<float> ApplyGuidance(Tensor<float> conditional, Tensor<float> u
140144
unconditional.Memory.Lerp(conditional.Memory, guidanceScale);
141145
return unconditional;
142146
}
147+
143148
}
144149
}

TensorStack.StableDiffusion/Pipelines/StableDiffusion/StableDiffusionPipeline.cs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,10 @@ private async Task<Tensor<float>> RunInferenceAsync(IPipelineOptions options, IS
274274
// Result
275275
latents = stepResult.Sample;
276276

277+
// Progress
278+
if (scheduler.IsFinalOrder)
279+
progressCallback.Notify(scheduler.CurrentStep, scheduler.TotalSteps, latents, steptime);
280+
277281
Logger.LogEnd(LogLevel.Debug, steptime, $"[RunInferenceAsync] Step: {i + 1}/{timesteps.Count}");
278282
}
279283

@@ -346,6 +350,10 @@ private async Task<Tensor<float>> RunInferenceAsync(IPipelineOptions options, Co
346350
// Result
347351
latents = stepResult.Sample;
348352

353+
// Progress
354+
if (scheduler.IsFinalOrder)
355+
progressCallback.Notify(scheduler.CurrentStep, scheduler.TotalSteps, latents, steptime);
356+
349357
Logger.LogEnd(LogLevel.Debug, steptime, $"[RunInferenceAsync] Step: {i + 1}/{timesteps.Count}");
350358
}
351359

@@ -463,7 +471,17 @@ protected override async Task CheckPipelineState(IPipelineOptions options)
463471
/// </summary>
464472
protected override IReadOnlyList<SchedulerType> ConfigureSchedulers()
465473
{
466-
return [SchedulerType.LMS, SchedulerType.Euler, SchedulerType.EulerAncestral, SchedulerType.LCM];
474+
return
475+
[
476+
SchedulerType.LMS,
477+
SchedulerType.Euler,
478+
SchedulerType.EulerAncestral,
479+
SchedulerType.DDPM,
480+
SchedulerType.DDIM,
481+
SchedulerType.KDPM2,
482+
SchedulerType.KDPM2Ancestral,
483+
SchedulerType.LCM
484+
];
467485
}
468486

469487

@@ -478,7 +496,7 @@ protected override GenerateOptions ConfigureDefaultOptions()
478496
Width = 512,
479497
Height = 512,
480498
GuidanceScale = 7.5f,
481-
Scheduler = SchedulerType.Euler,
499+
Scheduler = SchedulerType.EulerAncestral,
482500
TimestepSpacing = TimestepSpacingType.Trailing
483501
};
484502
}

TensorStack.StableDiffusion/Pipelines/StableDiffusion2/StableDiffusion2Pipeline.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ public StableDiffusion2Pipeline(StableDiffusion2Config configuration, ILogger lo
6464
/// </summary>
6565
protected override IReadOnlyList<SchedulerType> ConfigureSchedulers()
6666
{
67-
return [SchedulerType.LMS, SchedulerType.Euler, SchedulerType.EulerAncestral, SchedulerType.LCM];
67+
return [SchedulerType.DDPM, SchedulerType.DDIM];
6868
}
6969

7070

@@ -79,7 +79,7 @@ protected override GenerateOptions ConfigureDefaultOptions()
7979
Width = 768,
8080
Height = 768,
8181
GuidanceScale = 7.5f,
82-
Scheduler = SchedulerType.Euler,
82+
Scheduler = SchedulerType.DDPM,
8383
PredictionType = PredictionType.VariablePrediction
8484
};
8585
}

0 commit comments

Comments
 (0)