Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 76e34b4

Browse files
committed
Support multi-tensor scheduler results
1 parent f2751fe commit 76e34b4

File tree

12 files changed

+39
-19
lines changed

12 files changed

+39
-19
lines changed

OnnxStack.StableDiffusion/Common/IScheduler.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Microsoft.ML.OnnxRuntime.Tensors;
22
using OnnxStack.StableDiffusion.Enums;
3+
using OnnxStack.StableDiffusion.Schedulers;
34
using System;
45
using System.Collections.Generic;
56

@@ -38,7 +39,7 @@ public interface IScheduler : IDisposable
3839
/// <param name="sample">The sample.</param>
3940
/// <param name="order">The order.</param>
4041
/// <returns></returns>
41-
DenseTensor<float> Step(DenseTensor<float> modelOutput, int timestep, DenseTensor<float> sample, int order = 4);
42+
SchedulerStepResult Step(DenseTensor<float> modelOutput, int timestep, DenseTensor<float> sample, int order = 4);
4243

4344
/// <summary>
4445
/// Adds noise to the sample.

OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOp
106106
noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);
107107

108108
// Scheduler Step
109-
latents = scheduler.Step(noisePred, timestep, latents);
109+
latents = scheduler.Step(noisePred, timestep, latents).PreviousSample;
110110
}
111111

112112
progressCallback?.Invoke(step, timesteps.Count);

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintDiffuser.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ public override async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelO
8787
noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);
8888

8989
// Scheduler Step
90-
latents = scheduler.Step(noisePred, timestep, latents);
90+
latents = scheduler.Step(noisePred, timestep, latents).PreviousSample;
9191
}
9292

9393
progress?.Invoke(++step, timesteps.Count);

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintLegacyDiffuser.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ public override async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelO
8282
noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);
8383

8484
// Scheduler Step
85-
var steplatents = scheduler.Step(noisePred, timestep, latents);
85+
var steplatents = scheduler.Step(noisePred, timestep, latents).PreviousSample;
8686

8787
// Add noise to original latent
8888
var initLatentsProper = scheduler.AddNoise(latentsOriginal, noise, new[] { timestep });

OnnxStack.StableDiffusion/Schedulers/DDIMScheduler.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ public override DenseTensor<float> ScaleInput(DenseTensor<float> sample, int tim
8686
/// <returns></returns>
8787
/// <exception cref="System.ArgumentException">Invalid prediction_type: {SchedulerOptions.PredictionType}</exception>
8888
/// <exception cref="System.NotImplementedException">DDIMScheduler Thresholding currently not implemented</exception>
89-
public override DenseTensor<float> Step(DenseTensor<float> modelOutput, int timestep, DenseTensor<float> sample, int order = 4)
89+
public override SchedulerStepResult Step(DenseTensor<float> modelOutput, int timestep, DenseTensor<float> sample, int order = 4)
9090
{
9191
//# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
9292
//# Ideally, read DDIM paper in-detail understanding
@@ -176,7 +176,7 @@ public override DenseTensor<float> Step(DenseTensor<float> modelOutput, int time
176176
if (eta > 0)
177177
prevSample = prevSample.AddTensors(CreateRandomSample(modelOutput.Dimensions).MultipleTensorByFloat(stdDevT));
178178

179-
return prevSample;
179+
return new SchedulerStepResult(prevSample);
180180
}
181181

182182

OnnxStack.StableDiffusion/Schedulers/DDPMScheduler.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ public override DenseTensor<float> ScaleInput(DenseTensor<float> sample, int tim
8181
/// <returns></returns>
8282
/// <exception cref="System.ArgumentException">Invalid prediction_type: {SchedulerOptions.PredictionType}</exception>
8383
/// <exception cref="System.NotImplementedException">DDPMScheduler Thresholding currently not implemented</exception>
84-
public override DenseTensor<float> Step(DenseTensor<float> modelOutput, int timestep, DenseTensor<float> sample, int order = 4)
84+
public override SchedulerStepResult Step(DenseTensor<float> modelOutput, int timestep, DenseTensor<float> sample, int order = 4)
8585
{
8686
int currentTimestep = timestep;
8787
int previousTimestep = GetPreviousTimestep(currentTimestep);
@@ -159,7 +159,7 @@ public override DenseTensor<float> Step(DenseTensor<float> modelOutput, int time
159159
predPrevSample = predPrevSample.AddTensors(variance);
160160
}
161161

162-
return predPrevSample;
162+
return new SchedulerStepResult(predPrevSample);
163163
}
164164

165165

OnnxStack.StableDiffusion/Schedulers/EulerAncestralScheduler.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ public override DenseTensor<float> ScaleInput(DenseTensor<float> sample, int tim
103103
/// <param name="sample">The sample.</param>
104104
/// <param name="order">The order.</param>
105105
/// <returns></returns>
106-
public override DenseTensor<float> Step(DenseTensor<float> modelOutput, int timestep, DenseTensor<float> sample, int order = 4)
106+
public override SchedulerStepResult Step(DenseTensor<float> modelOutput, int timestep, DenseTensor<float> sample, int order = 4)
107107
{
108108
var stepIndex = Timesteps.IndexOf(timestep);
109109
var sigma = _sigmas[stepIndex];
@@ -129,7 +129,8 @@ public override DenseTensor<float> Step(DenseTensor<float> modelOutput, int time
129129
var delta = sigmaDown - sigma;
130130
var prevSample = sample.AddTensors(derivative.MultipleTensorByFloat(delta));
131131
var noise = CreateRandomSample(prevSample.Dimensions);
132-
return prevSample.AddTensors(noise.MultipleTensorByFloat(sigmaUp));
132+
prevSample = prevSample.AddTensors(noise.MultipleTensorByFloat(sigmaUp));
133+
return new SchedulerStepResult(prevSample);
133134
}
134135

135136

OnnxStack.StableDiffusion/Schedulers/EulerScheduler.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ public override DenseTensor<float> ScaleInput(DenseTensor<float> sample, int tim
108108
/// <param name="sample">The sample.</param>
109109
/// <param name="order">The order.</param>
110110
/// <returns></returns>
111-
public override DenseTensor<float> Step(DenseTensor<float> modelOutput, int timestep, DenseTensor<float> sample, int order = 4)
111+
public override SchedulerStepResult Step(DenseTensor<float> modelOutput, int timestep, DenseTensor<float> sample, int order = 4)
112112
{
113113
// TODO: Implement "extended settings for scheduler types"
114114
float s_churn = 0f;
@@ -140,7 +140,7 @@ public override DenseTensor<float> Step(DenseTensor<float> modelOutput, int time
140140
.DivideTensorByFloat(sigmaHat);
141141

142142
var delta = _sigmas[stepIndex + 1] - sigmaHat;
143-
return sample.AddTensors(derivative.MultipleTensorByFloat(delta));
143+
return new SchedulerStepResult(sample.AddTensors(derivative.MultipleTensorByFloat(delta)));
144144
}
145145

146146

OnnxStack.StableDiffusion/Schedulers/KDPM2Scheduler.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ public override DenseTensor<float> ScaleInput(DenseTensor<float> sample, int tim
116116
/// <returns></returns>
117117
/// <exception cref="System.ArgumentException">Invalid prediction_type: {SchedulerOptions.PredictionType}</exception>
118118
/// <exception cref="System.NotImplementedException">KDPM2Scheduler Thresholding currently not implemented</exception>
119-
public override DenseTensor<float> Step(DenseTensor<float> modelOutput, int timestep, DenseTensor<float> sample, int order = 4)
119+
public override SchedulerStepResult Step(DenseTensor<float> modelOutput, int timestep, DenseTensor<float> sample, int order = 4)
120120
{
121121
float sigma;
122122
float sigmaInterpol;
@@ -177,7 +177,7 @@ public override DenseTensor<float> Step(DenseTensor<float> modelOutput, int time
177177
}
178178

179179
_stepIndex += 1;
180-
return sample.AddTensors(derivative.MultipleTensorByFloat(dt));
180+
return new SchedulerStepResult(sample.AddTensors(derivative.MultipleTensorByFloat(dt)));
181181
}
182182

183183

OnnxStack.StableDiffusion/Schedulers/LCMScheduler.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ public override DenseTensor<float> ScaleInput(DenseTensor<float> sample, int tim
107107
/// <param name="sample">The sample.</param>
108108
/// <param name="order">The order.</param>
109109
/// <returns></returns>
110-
public override DenseTensor<float> Step(DenseTensor<float> modelOutput, int timestep, DenseTensor<float> sample, int order = 4)
110+
public override SchedulerStepResult Step(DenseTensor<float> modelOutput, int timestep, DenseTensor<float> sample, int order = 4)
111111
{
112112
//# Latent Consistency Models paper https://arxiv.org/abs/2310.04378
113113

@@ -164,7 +164,7 @@ public override DenseTensor<float> Step(DenseTensor<float> modelOutput, int time
164164
.AddTensors(denoised.MultipleTensorByFloat(MathF.Sqrt(alphaProdTPrev)))
165165
: denoised;
166166

167-
return prevSample;
167+
return new SchedulerStepResult(prevSample, denoised);
168168
}
169169

170170

0 commit comments

Comments
 (0)