Skip to content

Commit 0d76d49

Browse files
committed
Fix LowMemory mode
1 parent 26bfdc1 commit 0d76d49

File tree

6 files changed

+37
-24
lines changed

6 files changed

+37
-24
lines changed

TensorStack.Common/Tensor/Tensor.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,11 @@ public Tensor(ReadOnlySpan<int> dimensions, T fillWith)
7676
/// </summary>
7777
public ReadOnlySpan<int> Dimensions => _dimensions;
7878

79+
/// <summary>
80+
/// Gets the rank.
81+
/// </summary>
82+
public int Rank => _dimensions.Length;
83+
7984

8085
/// <summary>
8186
/// Gets or sets the <see cref="T"/> with the specified indices.

TensorStack.StableDiffusion/Config/AutoEncoderConfig.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ public record AutoEncoderModelConfig : ModelConfig
88
{
99
public int Scale { get; set; } = 8;
1010
public float ScaleFactor { get; set; }
11-
public float ShiftFactor { get; set; } = 1;
11+
public float ShiftFactor { get; set; }
1212
public int InChannels { get; set; } = 3;
1313
public int OutChannels { get; set; } = 3;
1414
public int LatentChannels { get; set; } = 4;

TensorStack.StableDiffusion/Pipelines/IPipelineOptions.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ public interface IPipelineOptions : IRunOptions
1616
string Prompt { get; set; }
1717
string NegativePrompt { get; set; }
1818
float GuidanceScale { get; set; }
19+
float GuidanceScale2 { get; set; }
1920
public SchedulerType Scheduler { get; set; }
2021

2122
float Strength { get; set; }

TensorStack.StableDiffusion/Pipelines/StableCascade/StableCascadePipeline.cs

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -191,11 +191,17 @@ private async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, Can
191191
var hiddenStateIndex = 1 + options.ClipSkip;
192192
var promptEmbeddings = await EncodePromptAsync(promptTokens, maxPromptTokenCount, hiddenStateIndex, cancellationToken);
193193
var negativePromptEmbeddings = await EncodePromptAsync(negativePromptTokens, maxPromptTokenCount, hiddenStateIndex, cancellationToken);
194-
if (options.IsLowMemoryTextEncoderEnabled)
194+
if (options.IsLowMemoryEnabled || options.IsLowMemoryTextEncoderEnabled)
195195
await TextEncoder.UnloadAsync();
196196

197-
var textEmbeds = promptEmbeddings.TextEmbeds.Reshape([1, .. promptEmbeddings.TextEmbeds.Dimensions]);
198-
var negativeTextEmbeds = negativePromptEmbeddings.TextEmbeds.Reshape([1, .. negativePromptEmbeddings.TextEmbeds.Dimensions]);
197+
var textEmbeds = promptEmbeddings.TextEmbeds.Rank == 3
198+
? promptEmbeddings.TextEmbeds
199+
: promptEmbeddings.TextEmbeds.Reshape([1, .. promptEmbeddings.TextEmbeds.Dimensions]);
200+
201+
var negativeTextEmbeds = negativePromptEmbeddings.TextEmbeds.Rank == 3
202+
? negativePromptEmbeddings.TextEmbeds
203+
: negativePromptEmbeddings.TextEmbeds.Reshape([1, .. negativePromptEmbeddings.TextEmbeds.Dimensions]);
204+
199205
return new PromptResult(promptEmbeddings.HiddenStates, textEmbeds, negativePromptEmbeddings.HiddenStates, negativeTextEmbeds);
200206
}
201207

@@ -297,7 +303,7 @@ private async Task<Tensor<float>> RunPriorAsync(GenerateOptions options, Tensor<
297303
}
298304

299305
// Unload if required
300-
if (options.IsLowMemoryComputeEnabled)
306+
if (options.IsLowMemoryEnabled || options.IsLowMemoryComputeEnabled)
301307
await PriorUnet.UnloadAsync();
302308

303309
Logger.LogEnd(LogLevel.Debug, timestamp, "[RunPriorAsync] Prior Inference Complete");
@@ -374,7 +380,7 @@ private async Task<Tensor<float>> RunDecoderAsync(GenerateOptions options, Tenso
374380
}
375381

376382
// Unload if required
377-
if (options.IsLowMemoryComputeEnabled)
383+
if (options.IsLowMemoryEnabled || options.IsLowMemoryComputeEnabled)
378384
await DecoderUnet.UnloadAsync();
379385

380386
Logger.LogEnd(LogLevel.Debug, timestamp, "[RunDecoderAsync] Decoder Inference Complete");
@@ -431,7 +437,7 @@ private Task<Tensor<float>> EncodeLatentsAsync(IPipelineOptions options, ImageTe
431437
private async Task<ImageTensor> DecodeLatentsAsync(IPipelineOptions options, Tensor<float> latents, CancellationToken cancellationToken = default)
432438
{
433439
var decoderResult = await ImageDecoder.RunAsync(latents, cancellationToken: cancellationToken);
434-
if (options.IsLowMemoryDecoderEnabled)
440+
if (options.IsLowMemoryEnabled || options.IsLowMemoryDecoderEnabled)
435441
await ImageDecoder.UnloadAsync();
436442

437443
return decoderResult.AsImageTensor();
@@ -445,15 +451,15 @@ private async Task<ImageTensor> DecodeLatentsAsync(IPipelineOptions options, Ten
445451
protected override async Task CheckPipelineState(IPipelineOptions options)
446452
{
447453
// Check LowMemory status
448-
if (options.IsLowMemoryTextEncoderEnabled && TextEncoder.IsLoaded())
454+
if ((options.IsLowMemoryEnabled || options.IsLowMemoryTextEncoderEnabled) && TextEncoder.IsLoaded())
449455
await TextEncoder.UnloadAsync();
450-
if (options.IsLowMemoryComputeEnabled && PriorUnet.IsLoaded())
456+
if ((options.IsLowMemoryEnabled || options.IsLowMemoryComputeEnabled) && PriorUnet.IsLoaded())
451457
await PriorUnet.UnloadAsync();
452-
if (options.IsLowMemoryComputeEnabled && DecoderUnet.IsLoaded())
458+
if ((options.IsLowMemoryEnabled || options.IsLowMemoryComputeEnabled) && DecoderUnet.IsLoaded())
453459
await DecoderUnet.UnloadAsync();
454-
if (options.IsLowMemoryEncoderEnabled && ImageDecoder.IsLoaded())
460+
if ((options.IsLowMemoryEnabled || options.IsLowMemoryEncoderEnabled) && ImageDecoder.IsLoaded())
455461
await ImageDecoder.UnloadAsync();
456-
if (options.IsLowMemoryDecoderEnabled && ImageEncoder.IsLoaded())
462+
if ((options.IsLowMemoryEnabled || options.IsLowMemoryDecoderEnabled) && ImageEncoder.IsLoaded())
457463
await ImageEncoder.UnloadAsync();
458464
}
459465

TensorStack.StableDiffusion/Pipelines/StableDiffusion3/StableDiffusion3Pipeline.cs

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -203,20 +203,20 @@ private async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, Can
203203
// TextEncoder
204204
var prompt1Embeddings = await EncodePromptAsync(promptTokens, maxPromptTokenCount, cancellationToken);
205205
var negativePrompt1Embeddings = await EncodePromptAsync(negativePromptTokens, maxPromptTokenCount, cancellationToken);
206-
if (options.IsLowMemoryTextEncoderEnabled)
206+
if (options.IsLowMemoryEnabled || options.IsLowMemoryTextEncoderEnabled)
207207
await TextEncoder.UnloadAsync();
208208

209209
// TextEncoder2
210210
var hiddenStateIndex = 2 + options.ClipSkip;
211211
var prompt2Embeddings = await EncodePrompt2Async(prompt2Tokens, maxPromptTokenCount, hiddenStateIndex, cancellationToken);
212212
var negativePrompt2Embeddings = await EncodePrompt2Async(negativePrompt2Tokens, maxPromptTokenCount, hiddenStateIndex, cancellationToken);
213-
if (options.IsLowMemoryTextEncoderEnabled)
213+
if (options.IsLowMemoryEnabled || options.IsLowMemoryTextEncoderEnabled)
214214
await TextEncoder2.UnloadAsync();
215215

216216
// TextEncoder3
217217
var prompt3Embeddings = await EncodePrompt3Async(prompt3Tokens, cancellationToken);
218218
var negativePrompt3Embeddings = await EncodePrompt3Async(negativePrompt3Tokens, cancellationToken);
219-
if (options.IsLowMemoryTextEncoderEnabled)
219+
if (options.IsLowMemoryEnabled || options.IsLowMemoryTextEncoderEnabled)
220220
await TextEncoder3.UnloadAsync();
221221

222222
// Positive Prompt
@@ -617,19 +617,19 @@ protected override async Task CheckPipelineState(IPipelineOptions options)
617617
await Transformer.UnloadControlNetAsync();
618618

619619
// Check LowMemory status
620-
if (options.IsLowMemoryTextEncoderEnabled && TextEncoder.IsLoaded())
620+
if ((options.IsLowMemoryEnabled || options.IsLowMemoryTextEncoderEnabled)&& TextEncoder.IsLoaded())
621621
await TextEncoder.UnloadAsync();
622-
if (options.IsLowMemoryComputeEnabled && Transformer.IsLoaded())
622+
if ((options.IsLowMemoryEnabled || options.IsLowMemoryComputeEnabled) && Transformer.IsLoaded())
623623
await Transformer.UnloadAsync();
624-
if (options.IsLowMemoryComputeEnabled && Transformer.IsControlNetLoaded())
624+
if ((options.IsLowMemoryEnabled || options.IsLowMemoryComputeEnabled)&& Transformer.IsControlNetLoaded())
625625
await Transformer.UnloadControlNetAsync();
626-
if (options.IsLowMemoryTextEncoderEnabled && TextEncoder3.IsLoaded())
626+
if ((options.IsLowMemoryEnabled || options.IsLowMemoryTextEncoderEnabled) && TextEncoder3.IsLoaded())
627627
await TextEncoder3.UnloadAsync();
628-
if (options.IsLowMemoryTextEncoderEnabled && TextEncoder2.IsLoaded())
628+
if ((options.IsLowMemoryEnabled || options.IsLowMemoryTextEncoderEnabled) && TextEncoder2.IsLoaded())
629629
await TextEncoder2.UnloadAsync();
630-
if (options.IsLowMemoryEncoderEnabled && AutoEncoder.IsEncoderLoaded())
630+
if ((options.IsLowMemoryEnabled || options.IsLowMemoryEncoderEnabled)&& AutoEncoder.IsEncoderLoaded())
631631
await AutoEncoder.EncoderUnloadAsync();
632-
if (options.IsLowMemoryDecoderEnabled && AutoEncoder.IsDecoderLoaded())
632+
if ((options.IsLowMemoryEnabled || options.IsLowMemoryDecoderEnabled) && AutoEncoder.IsDecoderLoaded())
633633
await AutoEncoder.DecoderUnloadAsync();
634634
}
635635

@@ -664,6 +664,7 @@ protected override GenerateOptions ConfigureDefaultOptions()
664664
return options with
665665
{
666666
Steps = 4,
667+
Shift = 3f,
667668
Width = 1024,
668669
Height = 1024,
669670
GuidanceScale = 0,

TensorStack.StableDiffusion/Pipelines/StableDiffusionXL/StableDiffusionXLPipeline.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,14 +180,14 @@ private async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, Can
180180
// TextEncoder
181181
var prompt1Embeddings = await EncodePromptAsync(promptTokens, maxPromptTokenCount, cancellationToken);
182182
var negativePrompt1Embeddings = await EncodePromptAsync(negativePromptTokens, maxPromptTokenCount, cancellationToken);
183-
if (options.IsLowMemoryTextEncoderEnabled)
183+
if (options.IsLowMemoryEnabled || options.IsLowMemoryTextEncoderEnabled)
184184
await TextEncoder.UnloadAsync();
185185

186186
// TextEncoder2
187187
var hiddenStateIndex = 2 + options.ClipSkip;
188188
var prompt2Embeddings = await EncodePrompt2Async(prompt2Tokens, maxPromptTokenCount, hiddenStateIndex, cancellationToken);
189189
var negativePrompt2Embeddings = await EncodePrompt2Async(negativePrompt2Tokens, maxPromptTokenCount, hiddenStateIndex, cancellationToken);
190-
if (options.IsLowMemoryTextEncoderEnabled)
190+
if (options.IsLowMemoryEnabled || options.IsLowMemoryTextEncoderEnabled)
191191
await TextEncoder2.UnloadAsync();
192192

193193
// Prompt embeds

0 commit comments

Comments
 (0)