@@ -191,11 +191,17 @@ private async Task<PromptResult> CreatePromptAsync(IPipelineOptions options, Can
191191 var hiddenStateIndex = 1 + options . ClipSkip ;
192192 var promptEmbeddings = await EncodePromptAsync ( promptTokens , maxPromptTokenCount , hiddenStateIndex , cancellationToken ) ;
193193 var negativePromptEmbeddings = await EncodePromptAsync ( negativePromptTokens , maxPromptTokenCount , hiddenStateIndex , cancellationToken ) ;
194- if ( options . IsLowMemoryTextEncoderEnabled )
194+ if ( options . IsLowMemoryEnabled || options . IsLowMemoryTextEncoderEnabled )
195195 await TextEncoder . UnloadAsync ( ) ;
196196
197- var textEmbeds = promptEmbeddings . TextEmbeds . Reshape ( [ 1 , .. promptEmbeddings . TextEmbeds . Dimensions ] ) ;
198- var negativeTextEmbeds = negativePromptEmbeddings . TextEmbeds . Reshape ( [ 1 , .. negativePromptEmbeddings . TextEmbeds . Dimensions ] ) ;
197+ var textEmbeds = promptEmbeddings . TextEmbeds . Rank == 3
198+ ? promptEmbeddings . TextEmbeds
199+ : promptEmbeddings . TextEmbeds . Reshape ( [ 1 , .. promptEmbeddings . TextEmbeds . Dimensions ] ) ;
200+
201+ var negativeTextEmbeds = negativePromptEmbeddings . TextEmbeds . Rank == 3
202+ ? negativePromptEmbeddings . TextEmbeds
203+ : negativePromptEmbeddings . TextEmbeds . Reshape ( [ 1 , .. negativePromptEmbeddings . TextEmbeds . Dimensions ] ) ;
204+
199205 return new PromptResult ( promptEmbeddings . HiddenStates , textEmbeds , negativePromptEmbeddings . HiddenStates , negativeTextEmbeds ) ;
200206 }
201207
@@ -297,7 +303,7 @@ private async Task<Tensor<float>> RunPriorAsync(GenerateOptions options, Tensor<
297303 }
298304
299305 // Unload if required
300- if ( options . IsLowMemoryComputeEnabled )
306+ if ( options . IsLowMemoryEnabled || options . IsLowMemoryComputeEnabled )
301307 await PriorUnet . UnloadAsync ( ) ;
302308
303309 Logger . LogEnd ( LogLevel . Debug , timestamp , "[RunPriorAsync] Prior Inference Complete" ) ;
@@ -374,7 +380,7 @@ private async Task<Tensor<float>> RunDecoderAsync(GenerateOptions options, Tenso
374380 }
375381
376382 // Unload if required
377- if ( options . IsLowMemoryComputeEnabled )
383+ if ( options . IsLowMemoryEnabled || options . IsLowMemoryComputeEnabled )
378384 await DecoderUnet . UnloadAsync ( ) ;
379385
380386 Logger . LogEnd ( LogLevel . Debug , timestamp , "[RunDecoderAsync] Decoder Inference Complete" ) ;
@@ -431,7 +437,7 @@ private Task<Tensor<float>> EncodeLatentsAsync(IPipelineOptions options, ImageTe
431437 private async Task < ImageTensor > DecodeLatentsAsync ( IPipelineOptions options , Tensor < float > latents , CancellationToken cancellationToken = default )
432438 {
433439 var decoderResult = await ImageDecoder . RunAsync ( latents , cancellationToken : cancellationToken ) ;
434- if ( options . IsLowMemoryDecoderEnabled )
440+ if ( options . IsLowMemoryEnabled || options . IsLowMemoryDecoderEnabled )
435441 await ImageDecoder . UnloadAsync ( ) ;
436442
437443 return decoderResult . AsImageTensor ( ) ;
@@ -445,15 +451,15 @@ private async Task<ImageTensor> DecodeLatentsAsync(IPipelineOptions options, Ten
445451 protected override async Task CheckPipelineState ( IPipelineOptions options )
446452 {
447453 // Check LowMemory status
448- if ( options . IsLowMemoryTextEncoderEnabled && TextEncoder . IsLoaded ( ) )
454+ if ( ( options . IsLowMemoryEnabled || options . IsLowMemoryTextEncoderEnabled ) && TextEncoder . IsLoaded ( ) )
449455 await TextEncoder . UnloadAsync ( ) ;
450- if ( options . IsLowMemoryComputeEnabled && PriorUnet . IsLoaded ( ) )
456+ if ( ( options . IsLowMemoryEnabled || options . IsLowMemoryComputeEnabled ) && PriorUnet . IsLoaded ( ) )
451457 await PriorUnet . UnloadAsync ( ) ;
452- if ( options . IsLowMemoryComputeEnabled && DecoderUnet . IsLoaded ( ) )
458+ if ( ( options . IsLowMemoryEnabled || options . IsLowMemoryComputeEnabled ) && DecoderUnet . IsLoaded ( ) )
453459 await DecoderUnet . UnloadAsync ( ) ;
454- if ( options . IsLowMemoryEncoderEnabled && ImageDecoder . IsLoaded ( ) )
460+ if ( ( options . IsLowMemoryEnabled || options . IsLowMemoryEncoderEnabled ) && ImageDecoder . IsLoaded ( ) )
455461 await ImageDecoder . UnloadAsync ( ) ;
456- if ( options . IsLowMemoryDecoderEnabled && ImageEncoder . IsLoaded ( ) )
462+ if ( ( options . IsLowMemoryEnabled || options . IsLowMemoryDecoderEnabled ) && ImageEncoder . IsLoaded ( ) )
457463 await ImageEncoder . UnloadAsync ( ) ;
458464 }
459465
0 commit comments