Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 2a1d771

Browse files
committed
Add low memory modes
1 parent da3a66d commit 2a1d771

40 files changed, with 351 additions and 132 deletions.

OnnxStack.Core/Model/OnnxModelSession.cs

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -67,14 +67,15 @@ public async Task LoadAsync()
6767
/// Unloads the model session.
6868
/// </summary>
6969
/// <returns></returns>
70-
public Task UnloadAsync()
70+
public async Task UnloadAsync()
7171
{
72+
await Task.Yield();
7273
if (_session is not null)
7374
{
74-
_metadata = null;
7575
_session.Dispose();
76+
_metadata = null;
77+
_session = null;
7678
}
77-
return Task.CompletedTask;
7879
}
7980

8081

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,7 @@
1+
using OnnxStack.StableDiffusion.Enums;
2+
3+
namespace OnnxStack.StableDiffusion.Config
4+
{
5+
public record PipelineOptions(string Name, MemoryModeType MemoryMode);
6+
7+
}

OnnxStack.StableDiffusion/Config/StableDiffusionModelSet.cs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@ public record StableDiffusionModelSet : IOnnxModelSetConfig
1414
public int SampleSize { get; set; } = 512;
1515
public DiffuserPipelineType PipelineType { get; set; }
1616
public List<DiffuserType> Diffusers { get; set; } = new List<DiffuserType>();
17-
17+
public MemoryModeType MemoryMode { get; set; }
1818
public int DeviceId { get; set; }
1919
public int InterOpNumThreads { get; set; }
2020
public int IntraOpNumThreads { get; set; }

OnnxStack.StableDiffusion/Diffusers/DiffuserBase.cs

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@ public abstract class DiffuserBase : IDiffuser
2222
protected readonly UNetConditionModel _unet;
2323
protected readonly AutoEncoderModel _vaeDecoder;
2424
protected readonly AutoEncoderModel _vaeEncoder;
25+
protected readonly MemoryModeType _memoryMode;
2526

2627
/// <summary>
2728
/// Initializes a new instance of the <see cref="DiffuserBase"/> class.
@@ -31,12 +32,13 @@ public abstract class DiffuserBase : IDiffuser
3132
/// <param name="vaeDecoder">The vae decoder.</param>
3233
/// <param name="vaeEncoder">The vae encoder.</param>
3334
/// <param name="logger">The logger.</param>
34-
public DiffuserBase(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
35+
public DiffuserBase(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
3536
{
3637
_logger = logger;
3738
_unet = unet;
3839
_vaeDecoder = vaeDecoder;
3940
_vaeEncoder = vaeEncoder;
41+
_memoryMode = memoryMode;
4042
}
4143

4244
/// <summary>
@@ -137,10 +139,15 @@ protected virtual async Task<DenseTensor<float>> DecodeLatentsAsync(PromptOption
137139
var results = await _vaeDecoder.RunInferenceAsync(inferenceParameters);
138140
using (var imageResult = results.First())
139141
{
142+
// Unload if required
143+
if (_memoryMode != MemoryModeType.Maximum)
144+
await _vaeDecoder.UnloadAsync();
145+
140146
_logger?.LogEnd("Latents decoded", timestamp);
141147
return imageResult.ToDenseTensor();
142148
}
143149
}
150+
144151
}
145152

146153

OnnxStack.StableDiffusion/Diffusers/InstaFlow/ControlNetDiffuser.cs

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -30,8 +30,8 @@ public class ControlNetDiffuser : InstaFlowDiffuser
3030
/// <param name="vaeDecoder">The vae decoder.</param>
3131
/// <param name="vaeEncoder">The vae encoder.</param>
3232
/// <param name="logger">The logger.</param>
33-
public ControlNetDiffuser(ControlNetModel controlNet, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
34-
: base(unet, vaeDecoder, vaeEncoder, logger)
33+
public ControlNetDiffuser(ControlNetModel controlNet, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
34+
: base(unet, vaeDecoder, vaeEncoder, memoryMode, logger)
3535
{
3636
_controlNet = controlNet;
3737
}
@@ -147,6 +147,13 @@ public override async Task<DenseTensor<float>> DiffuseAsync(PromptOptions prompt
147147
_logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
148148
}
149149

150+
// Unload if required
151+
if (_memoryMode != MemoryModeType.Maximum)
152+
{
153+
await _unet.UnloadAsync();
154+
await _controlNet.UnloadAsync();
155+
}
156+
150157
// Decode Latents
151158
return await DecodeLatentsAsync(promptOptions, schedulerOptions, latents);
152159
}

OnnxStack.StableDiffusion/Diffusers/InstaFlow/InstaFlowDiffuser.cs

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,8 +26,8 @@ public abstract class InstaFlowDiffuser : DiffuserBase
2626
/// <param name="vaeDecoder">The vae decoder.</param>
2727
/// <param name="vaeEncoder">The vae encoder.</param>
2828
/// <param name="logger">The logger.</param>
29-
public InstaFlowDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
30-
: base(unet, vaeDecoder, vaeEncoder, logger) { }
29+
public InstaFlowDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
30+
: base(unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }
3131

3232
/// <summary>
3333
/// Gets the type of the pipeline.
@@ -106,6 +106,10 @@ public override async Task<DenseTensor<float>> DiffuseAsync(PromptOptions prompt
106106
_logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
107107
}
108108

109+
// Unload if required
110+
if (_memoryMode != MemoryModeType.Maximum)
111+
await _unet.UnloadAsync();
112+
109113
// Decode Latents
110114
return await DecodeLatentsAsync(promptOptions, schedulerOptions, latents);
111115
}

OnnxStack.StableDiffusion/Diffusers/InstaFlow/TextDiffuser.cs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -20,8 +20,8 @@ public sealed class TextDiffuser : InstaFlowDiffuser
2020
/// <param name="vaeDecoder">The vae decoder.</param>
2121
/// <param name="vaeEncoder">The vae encoder.</param>
2222
/// <param name="logger">The logger.</param>
23-
public TextDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
24-
: base(unet, vaeDecoder, vaeEncoder, logger) { }
23+
public TextDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
24+
: base(unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }
2525

2626

2727
/// <summary>

OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ControlNetDiffuser.cs

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -29,8 +29,8 @@ public class ControlNetDiffuser : LatentConsistencyDiffuser
2929
/// <param name="vaeDecoder">The vae decoder.</param>
3030
/// <param name="vaeEncoder">The vae encoder.</param>
3131
/// <param name="logger">The logger.</param>
32-
public ControlNetDiffuser(ControlNetModel controlNet, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
33-
: base(unet, vaeDecoder, vaeEncoder, logger)
32+
public ControlNetDiffuser(ControlNetModel controlNet, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
33+
: base(unet, vaeDecoder, vaeEncoder, memoryMode, logger)
3434
{
3535
_controlNet = controlNet;
3636
}
@@ -144,6 +144,13 @@ public override async Task<DenseTensor<float>> DiffuseAsync(PromptOptions prompt
144144
_logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
145145
}
146146

147+
// Unload if required
148+
if (_memoryMode != MemoryModeType.Maximum)
149+
{
150+
await _unet.UnloadAsync();
151+
await _controlNet.UnloadAsync();
152+
}
153+
147154
// Decode Latents
148155
return await DecodeLatentsAsync(promptOptions, schedulerOptions, latents);
149156
}

OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ControlNetImageDiffuser.cs

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -27,8 +27,8 @@ public sealed class ControlNetImageDiffuser : ControlNetDiffuser
2727
/// <param name="vaeDecoder">The vae decoder.</param>
2828
/// <param name="vaeEncoder">The vae encoder.</param>
2929
/// <param name="logger">The logger.</param>
30-
public ControlNetImageDiffuser(ControlNetModel controlNet, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
31-
: base(controlNet, unet, vaeDecoder, vaeEncoder, logger) { }
30+
public ControlNetImageDiffuser(ControlNetModel controlNet, UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
31+
: base(controlNet, unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }
3232

3333

3434
/// <summary>
@@ -73,6 +73,10 @@ protected override async Task<DenseTensor<float>> PrepareLatentsAsync(PromptOpti
7373
var results = await _vaeEncoder.RunInferenceAsync(inferenceParameters);
7474
using (var result = results.First())
7575
{
76+
// Unload if required
77+
if (_memoryMode != MemoryModeType.Maximum)
78+
await _vaeEncoder.UnloadAsync();
79+
7680
var outputResult = result.ToDenseTensor();
7781
var scaledSample = outputResult.MultiplyBy(_vaeEncoder.ScaleFactor);
7882
return scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);

OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ImageDiffuser.cs

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,8 +26,8 @@ public sealed class ImageDiffuser : LatentConsistencyDiffuser
2626
/// <param name="vaeDecoder">The vae decoder.</param>
2727
/// <param name="vaeEncoder">The vae encoder.</param>
2828
/// <param name="logger">The logger.</param>
29-
public ImageDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, ILogger logger = default)
30-
: base(unet, vaeDecoder, vaeEncoder, logger) { }
29+
public ImageDiffuser(UNetConditionModel unet, AutoEncoderModel vaeDecoder, AutoEncoderModel vaeEncoder, MemoryModeType memoryMode, ILogger logger = default)
30+
: base(unet, vaeDecoder, vaeEncoder, memoryMode, logger) { }
3131

3232

3333
/// <summary>
@@ -70,6 +70,10 @@ protected override async Task<DenseTensor<float>> PrepareLatentsAsync(PromptOpti
7070
var results = await _vaeEncoder.RunInferenceAsync(inferenceParameters);
7171
using (var result = results.First())
7272
{
73+
// Unload if required
74+
if (_memoryMode != MemoryModeType.Maximum)
75+
await _vaeEncoder.UnloadAsync();
76+
7377
var outputResult = result.ToDenseTensor();
7478
var scaledSample = outputResult.MultiplyBy(_vaeEncoder.ScaleFactor);
7579
return scheduler.AddNoise(scaledSample, scheduler.CreateRandomSample(scaledSample.Dimensions), timesteps);

0 commit comments

Comments
 (0)