Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 8119c01

Browse files
committed
Process batch images one by one to lower VRAM usage
1 parent ad25a3a commit 8119c01

File tree

4 files changed

+59
-15
lines changed

4 files changed

+59
-15
lines changed

OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOp
113113
}
114114

115115
// Decode Latents
116-
return await DecodeLatents(modelOptions, schedulerOptions, latents);
116+
return await DecodeLatents(modelOptions, promptOptions, schedulerOptions, latents);
117117
}
118118
}
119119

@@ -123,26 +123,42 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOp
123123
/// <param name="options">The options.</param>
124124
/// <param name="latents">The latents.</param>
125125
/// <returns></returns>
126-
protected async Task<DenseTensor<float>> DecodeLatents(IModelOptions model, SchedulerOptions options, DenseTensor<float> latents)
126+
protected async Task<DenseTensor<float>> DecodeLatents(IModelOptions model, PromptOptions prompt, SchedulerOptions options, DenseTensor<float> latents)
127127
{
128128
// Scale and decode the image latents with vae.
129129
latents = latents.MultiplyBy(1.0f / model.ScaleFactor);
130130

131-
var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.VaeDecoder);
132-
var inputParameters = CreateInputParameters(NamedOnnxValue.CreateFromTensor(inputNames[0], latents));
133-
134-
// Run inference.
135-
using (var inferResult = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeDecoder, inputParameters))
131+
var images = prompt.BatchCount > 1
132+
? latents.Split(prompt.BatchCount)
133+
: new[] { latents };
134+
var imageTensors = new List<DenseTensor<float>>();
135+
foreach (var image in images)
136136
{
137-
var resultTensor = inferResult.FirstElementAs<DenseTensor<float>>();
138-
if (await _onnxModelService.IsEnabledAsync(model, OnnxModelType.SafetyChecker))
137+
var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.VaeDecoder);
138+
var inputParameters = CreateInputParameters(NamedOnnxValue.CreateFromTensor(inputNames[0], image));
139+
140+
// Run inference.
141+
using (var inferResult = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeDecoder, inputParameters))
139142
{
140-
// Check if image contains NSFW content,
141-
if (!await IsImageSafe(model, options, resultTensor))
142-
return resultTensor.CloneEmpty().ToDenseTensor(); //TODO: blank image?, exception?, null?
143+
var resultTensor = inferResult.FirstElementAs<DenseTensor<float>>();
144+
if (await _onnxModelService.IsEnabledAsync(model, OnnxModelType.SafetyChecker))
145+
{
146+
// Check if image contains NSFW content,
147+
if (!await IsImageSafe(model, options, resultTensor))
148+
{
149+
//TODO: blank image?, exception?, null?
150+
imageTensors.Add(resultTensor.CloneEmpty().ToDenseTensor());
151+
continue;
152+
}
153+
}
154+
155+
if (prompt.BatchCount == 1)
156+
return resultTensor.ToDenseTensor();
157+
158+
imageTensors.Add(resultTensor.ToDenseTensor());
143159
}
144-
return resultTensor.ToDenseTensor();
145160
}
161+
return imageTensors.Join();
146162
}
147163

148164

OnnxStack.StableDiffusion/Diffusers/InpaintDiffuser.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ public override async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelO
9494
}
9595

9696
// Decode Latents
97-
return await DecodeLatents(modelOptions, schedulerOptions, latents);
97+
return await DecodeLatents(modelOptions, promptOptions, schedulerOptions, latents);
9898
}
9999
}
100100

OnnxStack.StableDiffusion/Diffusers/InpaintLegacyDiffuser.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ public override async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelO
9595
}
9696

9797
// Decode Latents
98-
return await DecodeLatents(modelOptions, schedulerOptions, latents);
98+
return await DecodeLatents(modelOptions, promptOptions, schedulerOptions, latents);
9999
}
100100
}
101101

OnnxStack.StableDiffusion/Helpers/TensorHelper.cs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Microsoft.ML.OnnxRuntime.Tensors;
22
using System;
3+
using System.Collections.Generic;
34
using System.Linq;
45

56
namespace OnnxStack.StableDiffusion.Helpers
@@ -366,6 +367,33 @@ public static DenseTensor<float>[] Split(this DenseTensor<float> tensor, int cou
366367
}
367368

368369

370+
/// <summary>
371+
/// Joins the tensors across the 0 axis.
372+
/// </summary>
373+
/// <param name="tensors">The tensors.</param>
374+
/// <param name="axis">The axis.</param>
375+
/// <returns></returns>
376+
/// <exception cref="System.NotImplementedException">Only axis 0 is supported</exception>
377+
public static DenseTensor<float> Join(this IList<DenseTensor<float>> tensors, int axis = 0)
378+
{
379+
if (axis != 0)
380+
throw new NotImplementedException("Only axis 0 is supported");
381+
382+
var tensor = tensors.First();
383+
var dimensions = tensor.Dimensions.ToArray();
384+
dimensions[0] *= tensors.Count;
385+
386+
var newLength = (int)tensor.Length;
387+
var buffer = new float[newLength * tensors.Count].AsMemory();
388+
for (int i = 0; i < tensors.Count(); i++)
389+
{
390+
var start = i * newLength;
391+
tensors[i].Buffer.CopyTo(buffer[start..]);
392+
}
393+
return new DenseTensor<float>(buffer, dimensions);
394+
}
395+
396+
369397
/// <summary>
370398
/// Adds the tensors.
371399
/// </summary>

0 commit comments

Comments
 (0)