Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 8d5575a

Browse files
committed
Normalize VQGAN output
1 parent 995c9eb commit 8d5575a

File tree

5 files changed

+40
-8
lines changed

5 files changed

+40
-8
lines changed

OnnxStack.Console/Examples/StableCascadeExample.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public async Task RunAsync()
4949
{
5050
SchedulerType = StableDiffusion.Enums.SchedulerType.DDPM,
5151
GuidanceScale =4f,
52-
InferenceSteps = 60,
52+
InferenceSteps = 20,
5353
Width = 1024,
5454
Height = 1024
5555
};
@@ -60,7 +60,7 @@ public async Task RunAsync()
6060
// Run pipeline
6161
var result = await pipeline.RunAsync(promptOptions, schedulerOptions, progressCallback: OutputHelpers.ProgressCallback);
6262

63-
var image = new OnnxImage(result, ImageNormalizeType.ZeroToOne);
63+
var image = new OnnxImage(result);
6464

6565
// Save Image File
6666
await image.SaveAsync(Path.Combine(_outputDirectory, $"output.png"));

OnnxStack.Core/Extensions/Extensions.cs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ public static long[] ToLong(this int[] array)
251251
/// Normalize the data using Min-Max scaling to ensure all values are in the range [0, 1].
252252
/// </summary>
253253
/// <param name="values">The values.</param>
254-
public static void NormalizeMinMax(this Span<float> values)
254+
public static Span<float> NormalizeZeroToOne(this Span<float> values)
255255
{
256256
float min = float.PositiveInfinity, max = float.NegativeInfinity;
257257
foreach (var val in values)
@@ -265,6 +265,23 @@ public static void NormalizeMinMax(this Span<float> values)
265265
{
266266
values[i] = (values[i] - min) / range;
267267
}
268+
return values;
269+
}
270+
271+
272+
/// <summary>
/// Rescales every value <c>v</c> in the span to <c>(v * 2) - 1</c>, in place.
/// Inputs that lie in the range [0, 1] therefore end up in the range [-1, 1];
/// values outside that range are not clamped.
/// </summary>
/// <param name="values">The values to rescale. The span is modified in place.</param>
/// <returns>The same span that was passed in, to allow call chaining.</returns>
public static Span<float> NormalizeOneToOne(this Span<float> values)
{
    // The transform is purely element-wise, so no min/max pre-scan is needed.
    // Skipping it also makes an empty span a harmless no-op instead of an
    // IndexOutOfRangeException from reading values[0].
    for (var i = 0; i < values.Length; i++)
    {
        values[i] = (values[i] * 2) - 1;
    }

    return values;
}
269286
}
270287
}

OnnxStack.Core/Extensions/OrtValueExtensions.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,18 @@ public static DenseTensor<float> ToDenseTensor(this OrtValue ortValue, ReadOnlyS
127127
}
128128

129129

130+
/// <summary>
/// Builds a <see cref="DenseTensor{T}"/> of <see cref="float"/> from a copy of the span's contents.
/// </summary>
/// <param name="ortSpanValue">The span whose values are copied into the new tensor.</param>
/// <param name="dimensions">The dimensions to give the resulting tensor.</param>
/// <returns>A new tensor created from an array copy of <paramref name="ortSpanValue"/> and the supplied dimensions.</returns>
public static DenseTensor<float> ToDenseTensor(this Span<float> ortSpanValue, ReadOnlySpan<int> dimensions)
{
    // ToArray() snapshots the span into a fresh array, so the tensor is built from its own copy.
    float[] copiedValues = ortSpanValue.ToArray();
    return new DenseTensor<float>(copiedValues, dimensions);
}
140+
141+
130142
/// <summary>
131143
/// Converts to array.
132144
/// TODO: Optimization

OnnxStack.Core/Extensions/TensorExtension.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ public static DenseTensor<float> Repeat(this DenseTensor<float> tensor1, int cou
287287
/// <param name="tensor">The tensor.</param>
288288
public static void NormalizeMinMax(this DenseTensor<float> tensor)
289289
{
290-
tensor.Buffer.Span.NormalizeMinMax();
290+
tensor.Buffer.Span.NormalizeZeroToOne();
291291
}
292292

293293

OnnxStack.StableDiffusion/Diffusers/StableCascade/StableCascadeDiffuser.cs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -196,21 +196,24 @@ protected override async Task<DenseTensor<float>> DecodeLatentsAsync(PromptOptio
196196
{
197197
latents = latents.MultiplyBy(_vaeDecoder.ScaleFactor);
198198

199-
var outputDim = new[] { 1, 4, 256, 256 };
199+
var outputDim = new[] { 1, 3, options.Height, options.Width };
200200
var metadata = await _vaeDecoder.GetMetadataAsync();
201201
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
202202
{
203203
inferenceParameters.AddInputTensor(latents);
204-
inferenceParameters.AddOutputBuffer();
204+
inferenceParameters.AddOutputBuffer(outputDim);
205205

206-
var results = _vaeDecoder.RunInference(inferenceParameters);
206+
var results = await _vaeDecoder.RunInferenceAsync(inferenceParameters);
207207
using (var imageResult = results.First())
208208
{
209209
// Unload if required
210210
if (_memoryMode == MemoryModeType.Minimum)
211211
await _vaeDecoder.UnloadAsync();
212212

213-
return imageResult.ToDenseTensor();
213+
return imageResult
214+
.GetTensorMutableDataAsSpan<float>()
215+
.NormalizeOneToOne()
216+
.ToDenseTensor(outputDim);
214217
}
215218
}
216219
}

0 commit comments

Comments
 (0)