 using Microsoft.Extensions.Logging;
-using Microsoft.ML.OnnxRuntime;
 using Microsoft.ML.OnnxRuntime.Tensors;
 using OnnxStack.Core;
 using OnnxStack.Core.Config;
 using OnnxStack.StableDiffusion.Config;
 using OnnxStack.StableDiffusion.Enums;
 using OnnxStack.StableDiffusion.Helpers;
+using OnnxStack.StableDiffusion.Models;
 using SixLabors.ImageSharp;
 using System;
 using System.Collections.Generic;
+using System.Diagnostics;
 using System.Linq;
+using System.Threading;
 using System.Threading.Tasks;

 namespace OnnxStack.StableDiffusion.Diffusers.ControlNet
@@ -35,6 +37,106 @@ public ImageDiffuser(IOnnxModelService onnxModelService, IPromptService promptSe
         /// </summary>
         public override DiffuserType DiffuserType => DiffuserType.ImageToImage;

+        /// <summary>
+        /// Called on each Scheduler step.
+        /// </summary>
+        /// <param name="modelOptions">The model options.</param>
+        /// <param name="promptOptions">The prompt options.</param>
+        /// <param name="schedulerOptions">The scheduler options.</param>
+        /// <param name="promptEmbeddings">The prompt embeddings.</param>
+        /// <param name="performGuidance">If set to <c>true</c>, classifier-free guidance is performed.</param>
+        /// <param name="progressCallback">The progress callback.</param>
+        /// <param name="cancellationToken">The cancellation token.</param>
+        /// <returns>The decoded image tensor.</returns>
+        protected override async Task<DenseTensor<float>> SchedulerStepAsync(StableDiffusionModelSet modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, PromptEmbeddingsResult promptEmbeddings, bool performGuidance, Action<DiffusionProgress> progressCallback = null, CancellationToken cancellationToken = default)
+        {
+            // Get Scheduler
+            using (var scheduler = GetScheduler(schedulerOptions))
+            {
+                // Get timesteps
+                var timesteps = GetTimesteps(schedulerOptions, scheduler);
+
+                // Create latent sample
+                var latents = await PrepareLatentsAsync(modelOptions, promptOptions, schedulerOptions, scheduler, timesteps);
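+                // For image-to-image diffusion the initial latents are based on the encoded input
+                // image with noise added for the starting timestep, rather than pure random noise.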
+
+                // Get Unet model metadata
+                var metadata = _onnxModelService.GetModelMetadata(modelOptions, OnnxModelType.Unet);
+
+                // Get ControlNet model metadata
+                var controlNetMetadata = _onnxModelService.GetModelMetadata(modelOptions, OnnxModelType.Control);
+
+                // TODO: do we need to pre-process?
+                var controlImage = promptOptions.InputImage.ToDenseTensor(new[] { 1, 3, schedulerOptions.Height, schedulerOptions.Width });
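+                // The conditioning image is converted to a 1x3xHxW float tensor matching the
+                // requested output size and is fed to the ControlNet model on every step.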
+
+                // Loop through the timesteps
+                var step = 0;
+                foreach (var timestep in timesteps)
+                {
+                    step++;
+                    var stepTime = Stopwatch.GetTimestamp();
+                    cancellationToken.ThrowIfCancellationRequested();
+
+                    // Create input tensor.
+                    var inputLatent = performGuidance ? latents.Repeat(2) : latents;
+                    var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
+                    var timestepTensor = CreateTimestepTensor(timestep);
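+                    // When guidance is enabled the latents are repeated along the batch dimension so the
+                    // unconditional and conditional noise predictions are produced in a single Unet pass.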
+
+                    var outputChannels = performGuidance ? 2 : 1;
+                    var outputDimension = schedulerOptions.GetScaledDimension(outputChannels);
+                    using (var inferenceParameters = new OnnxInferenceParameters(metadata))
+                    {
+                        inferenceParameters.AddInputTensor(inputTensor);
+                        inferenceParameters.AddInputTensor(timestepTensor);
+                        inferenceParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
+
+                        // ControlNet
+                        using (var controlNetParameters = new OnnxInferenceParameters(controlNetMetadata))
+                        {
+                            controlNetParameters.AddInputTensor(inputTensor);
+                            controlNetParameters.AddInputTensor(timestepTensor);
+                            controlNetParameters.AddInputTensor(promptEmbeddings.PromptEmbeds);
+                            controlNetParameters.AddInputTensor(controlImage);
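+                            // Register one output buffer per ControlNet output
+                            // (typically the down-block and mid-block residuals).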
+                            foreach (var item in controlNetMetadata.Outputs)
+                                controlNetParameters.AddOutputBuffer();
+
+                            var controlNetResults = _onnxModelService.RunInference(modelOptions, OnnxModelType.Control, controlNetParameters);
+                            if (controlNetResults.IsNullOrEmpty())
+                                throw new Exception("Control model produced no result");
+
+                            // Add ControlNet outputs to Unet input
+                            foreach (var item in controlNetResults)
+                                inferenceParameters.AddInputTensor(item.ToDenseTensor());
+                        }
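+
+                        // At this point the Unet inputs are: latent sample, timestep, prompt embeddings,
+                        // followed by the ControlNet residuals in the order the model produced them.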
+
+                        // Add output buffer
+                        inferenceParameters.AddOutputBuffer(outputDimension);
+
+                        // Unet
+                        var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inferenceParameters);
+                        using (var result = results.First())
+                        {
+                            var noisePred = result.ToDenseTensor();
+
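+                            // When guidance is enabled the prediction holds both the unconditional and
+                            // conditional results; PerformGuidance blends them using schedulerOptions.GuidanceScale.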
+                            // Perform guidance
+                            if (performGuidance)
+                                noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);
+
+                            // Scheduler Step
+                            latents = scheduler.Step(noisePred, timestep, latents).Result;
+                        }
+                    }
+
+                    ReportProgress(progressCallback, step, timesteps.Count, latents);
+                    _logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
+                }
+
+                // Decode the final latents back to the output image tensor
+                return await DecodeLatentsAsync(modelOptions, promptOptions, schedulerOptions, latents);
+            }
+        }
+

         /// <summary>
         /// Gets the timesteps.