@@ -34,6 +34,7 @@ public class StableDiffusionPipeline : PipelineBase
3434 protected List < DiffuserType > _supportedDiffusers ;
3535 protected IReadOnlyList < SchedulerType > _supportedSchedulers ;
3636 protected SchedulerOptions _defaultSchedulerOptions ;
37+ private UnetModeType _currentUnetMode ;
3738
3839 protected sealed record BatchResultInternal ( SchedulerOptions SchedulerOptions , List < DenseTensor < float > > Result ) ;
3940
@@ -110,6 +111,11 @@ public StableDiffusionPipeline(PipelineOptions pipelineOptions, TokenizerModel t
110111 /// </summary>
111112 public override SchedulerOptions DefaultSchedulerOptions => _defaultSchedulerOptions ;
112113
114+ /// <summary>
115+ /// Gets the current unet mode.
116+ /// </summary>
117+ public override UnetModeType CurrentUnetMode => _currentUnetMode ;
118+
113119 /// <summary>
114120 /// Gets the unet.
115121 /// </summary>
@@ -144,22 +150,29 @@ public StableDiffusionPipeline(PipelineOptions pipelineOptions, TokenizerModel t
144150 /// <summary>
145151 /// Loads the pipeline.
146152 /// </summary>
147- public override Task LoadAsync ( bool controlNet = false )
153+ public override Task LoadAsync ( UnetModeType unetMode = UnetModeType . Default )
148154 {
155+ _currentUnetMode = unetMode ;
149156 if ( _pipelineOptions . MemoryMode == MemoryModeType . Minimum )
150157 return Task . CompletedTask ;
151158
152- // Preload all models into VRAM
153- return Task . WhenAll
159+ var unetModels = Task . CompletedTask ;
160+ if ( _currentUnetMode == UnetModeType . Default )
161+ unetModels = Task . WhenAll ( _unet . LoadAsync ( ) , _controlNetUnet ? . UnloadAsync ( ) ?? Task . CompletedTask ) ;
162+ if ( _currentUnetMode == UnetModeType . ControlNet )
163+ unetModels = Task . WhenAll ( _controlNetUnet . LoadAsync ( ) , _unet . UnloadAsync ( ) ) ;
164+ if ( _currentUnetMode == UnetModeType . Both )
165+ unetModels = Task . WhenAll ( _unet . LoadAsync ( ) , _controlNetUnet ? . LoadAsync ( ) ?? Task . CompletedTask ) ;
166+
167+ var subModels = Task . WhenAll
154168 (
155- controlNet
156- ? _controlNetUnet . LoadAsync ( )
157- : _unet . LoadAsync ( ) ,
158- _tokenizer . LoadAsync ( ) ,
169+ _tokenizer . LoadAsync ( ) ,
159170 _textEncoder . LoadAsync ( ) ,
160171 _vaeDecoder . LoadAsync ( ) ,
161172 _vaeEncoder . LoadAsync ( )
162173 ) ;
174+
175+ return Task . WhenAll ( unetModels , subModels ) ;
163176 }
164177
165178
@@ -695,4 +708,5 @@ public static StableDiffusionPipeline CreatePipeline(string modelFolder, ModelTy
695708 return CreatePipeline ( ModelFactory . CreateModelSet ( modelFolder , DiffuserPipelineType . StableDiffusion , modelType , deviceId , executionProvider , memoryMode ) , logger ) ;
696709 }
697710 }
711+
698712}
0 commit comments